Compare commits

...

104 Commits

Author SHA1 Message Date
Luis Pater
34c8ccb961 Fixed: #437
feat(runtime): strip `service_tier` in GitHub Copilot response normalization
2026-03-13 11:50:21 +08:00
Luis Pater
d08e164af3 chore(runtime): remove unused FetchAntigravityModels function from executor 2026-03-13 11:38:44 +08:00
Luis Pater
8178efaeda Merge pull request #439 from router-for-me/plus
v6.8.52
2026-03-13 11:30:25 +08:00
Luis Pater
86d5db472a Merge branch 'main' into plus 2026-03-13 11:28:52 +08:00
Luis Pater
020d36f6e8 Merge pull request #433 from LuxVTZ/feat/gitlab-duo-auth-plus
Add GitLab Duo management OAuth and PAT endpoints
2026-03-13 11:25:19 +08:00
Luis Pater
1db23979e8 Merge pull request #2106 from router-for-me/model
feat(model_registry): enhance model registration and refresh mechanisms
2026-03-13 11:18:51 +08:00
hkfires
c3d5dbe96f feat(model_registry): enhance model registration and refresh mechanisms 2026-03-13 10:56:39 +08:00
Luis Pater
5484489406 chore(ci): update model catalog fetch method in workflows 2026-03-12 11:19:24 +08:00
Luis Pater
0ac52da460 chore(ci): update model catalog fetch method in release workflow 2026-03-12 10:50:46 +08:00
Luis Pater
817cebb321 Merge pull request #2082 from router-for-me/antigravity
Refactor Antigravity model handling and improve logging
2026-03-12 10:39:13 +08:00
Luis Pater
683f3709d6 Merge pull request #2076 from aikins01/fix/backfill-empty-function-response-names
fix: backfill empty functionResponse.name from preceding functionCall
2026-03-12 10:35:44 +08:00
hkfires
dbd42a42b2 fix(model_updater): clarify log message for model refresh failure 2026-03-12 10:32:04 +08:00
hkfires
ec24baf757 feat(fetch_antigravity_models): add command to fetch and save Antigravity model list 2026-03-12 10:21:09 +08:00
hkfires
dea3e74d35 feat(antigravity): refactor model handling and remove unused code 2026-03-12 09:24:45 +08:00
Aikins Laryea
a6c3042e34 refactor: remove redundant bounds checks per code review 2026-03-12 00:12:43 +00:00
Aikins Laryea
861537c9bd fix: backfill empty functionResponse.name from preceding functionCall
when Amp or Claude Code sends functionResponse with an empty name in Gemini
conversation history, the Gemini API rejects the request with 400
"Name cannot be empty". this fix backfills empty names from the
corresponding preceding functionCall parts using positional matching.

covers all three Gemini translator paths:
- gemini/gemini (direct API key)
- antigravity/gemini (OAuth)
- gemini-cli/gemini (Gemini CLI)

also switches fixCLIToolResponse pending group matching from LIFO to
FIFO to correctly handle multiple sequential tool call groups.

fixes #1903
2026-03-12 00:00:38 +00:00
Luis Pater
8c92cb0883 Merge pull request #2056 from lang-911/codex/custom-useragent-request
feat(config/codex): Add Codex header defaults (`user-agent`: override; `beta-features`: default)
2026-03-11 22:56:36 +08:00
Luis Pater
89d7be9525 Merge branch 'dev' into codex/custom-useragent-request 2026-03-11 22:55:50 +08:00
lang-911
2b79d7f22f fix: restore double quotes style in config.example.yaml for consistency and readability 2026-03-11 06:59:26 -07:00
LuxVTZ
2bb686f594 Add GitLab Duo management OAuth and PAT endpoints 2026-03-11 17:58:34 +04:00
lang-911
163fe287ce fix: codex header defaults example 2026-03-11 06:55:03 -07:00
lang-911
70988d387b Add Codex websocket header defaults 2026-03-11 00:34:57 -07:00
Luis Pater
52058a1659 docs: remove GitLab Duo sections from README and README_CN 2026-03-11 11:51:17 +08:00
Luis Pater
df5595a0c9 Merge pull request #428 from LuxVTZ/feat/gitlab-duo-auth-plus
Add GitLab Duo provider support
2026-03-11 11:50:02 +08:00
Luis Pater
ddaa9d2436 Fixed: #2034
feat(proxy): centralize proxy handling with `proxyutil` package and enhance test coverage

- Added `proxyutil` package to simplify proxy handling across the codebase.
- Refactored various components (`executor`, `cliproxy`, `auth`, etc.) to use `proxyutil` for consistent and reusable proxy logic.
- Introduced support for "direct" proxy mode to explicitly bypass all proxies.
- Updated tests to validate proxy behavior (e.g., `direct`, HTTP/HTTPS, and SOCKS5).
- Enhanced YAML configuration documentation for proxy options.
2026-03-11 11:08:02 +08:00
Luis Pater
7b7b258c38 Fixed: #2022
test(translator): add tests for handling Claude system messages as string and array
2026-03-11 10:47:33 +08:00
LuxVTZ
a00f774f5a Add GitLab Duo usage docs 2026-03-10 22:20:40 +04:00
LuxVTZ
9daf1ba8b5 test(gitlab): add duo openai handler smoke 2026-03-10 22:19:36 +04:00
LuxVTZ
76f2359637 test(gitlab): add duo claude handler smoke 2026-03-10 22:19:36 +04:00
LuxVTZ
dcb1c9be8a feat(gitlab): route duo openai via gateway 2026-03-10 22:19:36 +04:00
LuxVTZ
a24f4ace78 feat(gitlab): route duo anthropic via gateway 2026-03-10 22:19:36 +04:00
LuxVTZ
c631df8c3b feat(gitlab): add duo streaming transport 2026-03-10 22:19:36 +04:00
LuxVTZ
54c3eb1b1e Add GitLab Duo auth and executor support 2026-03-10 22:19:36 +04:00
LuxVTZ
bb28cd26ad Add GitLab Duo OAuth and PAT support 2026-03-10 22:18:54 +04:00
Luis Pater
046865461e Merge PR #424 from router-for-me/main 2026-03-10 19:19:29 +08:00
Luis Pater
cf74ed2f0c Merge pull request #2013 from router-for-me/model
Fetch model catalog from network
2026-03-10 19:07:23 +08:00
hkfires
e333fbea3d feat(updater): update StartModelsUpdater to block until models refresh completes 2026-03-10 14:41:58 +08:00
hkfires
efbe36d1d4 feat(updater): change models refresh to one-time fetch on startup 2026-03-10 14:18:54 +08:00
hkfires
8553cfa40e feat(workflows): refresh models catalog in workflows 2026-03-10 14:03:31 +08:00
hkfires
30d5c95b26 feat(registry): refresh model catalog from network 2026-03-10 14:02:54 +08:00
hkfires
d1e3195e6f feat(codex): register models by plan tier 2026-03-10 11:20:37 +08:00
Luis Pater
05a35662ae Merge branch 'router-for-me:main' into main 2026-03-09 23:05:51 +08:00
Luis Pater
ce53d3a287 Fixed: #1997
test(auth-scheduler): add benchmarks and priority-based scheduling improvements

- Added `BenchmarkManagerPickNextMixedPriority500` for mixed-priority performance assessment.
- Updated `pickNextMixed` to prioritize highest ready priority tiers.
- Introduced `highestReadyPriorityLocked` and `pickReadyAtPriorityLocked` for better scheduling logic.
- Added unit test to validate selection of highest priority tiers in mixed provider scenarios.
2026-03-09 22:27:15 +08:00
Luis Pater
4cc99e7449 Merge pull request #1992 from dcrdev/main
System prompt silently dropped when sent as a string
2026-03-09 21:03:15 +08:00
Luis Pater
71773fe032 Merge pull request #1996 from router-for-me/codex/fix-unbounded-websocket-log-buffering
fix: cap websocket body log growth in responses handler
2026-03-09 20:50:38 +08:00
Dominic Robinson
a1e0fa0f39 test(executor): cover string system prompt handling in checkSystemInstructionsWithMode 2026-03-09 12:40:27 +00:00
Supra4E8C
fc2f0b6983 fix: cap websocket body log growth 2026-03-09 17:48:30 +08:00
Dominic Robinson
5c9997cdac fix: Preserve system prompt when sent as a string instead of content block array 2026-03-09 07:38:11 +00:00
Luis Pater
6f81046730 docs: remove outdated sections from README and README_CN 2026-03-09 09:35:25 +08:00
Luis Pater
0687472d01 Merge pull request #422 from router-for-me/plus
v6.8.49
2026-03-09 09:34:05 +08:00
Luis Pater
7739738fb3 Merge branch 'main' into plus 2026-03-09 09:33:22 +08:00
Luis Pater
99d1ce247b Merge pull request #420 from Skadli/codex/responses-computer-tool
Fixed: preserve Responses computer tool passthrough
2026-03-09 09:31:30 +08:00
Luis Pater
f5941a411c test(auth): cover scheduler refresh regression paths 2026-03-09 09:27:56 +08:00
Luis Pater
ba672bbd07 Merge PR #1969 into dev 2026-03-09 09:25:06 +08:00
Luis Pater
d9c6627a53 Merge pull request #1963 from qixing-jk/docs/add-all-api-hub-showcase
docs: add All API Hub to related projects list
2026-03-09 09:16:41 +08:00
Luis Pater
2e9907c3ac Merge pull request #1959 from thebtf/fix/system-instruction-camelcase
fix: use camelCase systemInstruction in OpenAI-to-Gemini translators
2026-03-09 09:09:03 +08:00
DragonFSKY
90afb9cb73 fix(auth): new OAuth accounts invisible to scheduler after dynamic registration
When new OAuth auth files are added while the service is running,
`applyCoreAuthAddOrUpdate` calls `coreManager.Register()` (which upserts
into the scheduler) BEFORE `registerModelsForAuth()`. At upsert time,
`buildScheduledAuthMeta` snapshots `supportedModelSetForAuth` from the
global model registry — but models haven't been registered yet, so the
set is empty. With an empty `supportedModelSet`, `supportsModel()`
always returns false and the new auth is never added to any model shard.

Additionally, when all existing accounts are in cooldown, the scheduler
returns `modelCooldownError`, but `shouldRetrySchedulerPick` only
handles `*Error` types — so the `syncScheduler` safety-net rebuild
never triggers and the new accounts remain invisible.

Fix:
1. Add `RefreshSchedulerEntry()` to re-upsert a single auth after its
   models are registered, rebuilding `supportedModelSet` from the
   now-populated registry.
2. Call it from `applyCoreAuthAddOrUpdate` after `registerModelsForAuth`.
3. Make `shouldRetrySchedulerPick` also match `*modelCooldownError` so
   the full scheduler rebuild triggers when all credentials are cooling
   down — catching any similar stale-snapshot edge cases.
2026-03-09 03:11:47 +08:00
anime
d0cc0cd9a5 docs: add All API Hub to related projects list
- Update README.md with All API Hub entry in English
- Update README_CN.md with All API Hub entry in Chinese
2026-03-09 02:00:16 +08:00
Kirill Turanskiy
338321e553 fix: use camelCase systemInstruction in OpenAI-to-Gemini translators
The Gemini v1internal (cloudcode-pa) and Antigravity Manager endpoints
require camelCase "systemInstruction" in request JSON. The current
snake_case "system_instruction" causes system prompts to be silently
ignored when routing through these endpoints.

Replace all "system_instruction" JSON keys with "systemInstruction" in
chat-completions and responses request translators.
2026-03-08 15:59:13 +03:00
Luis Pater
182b31963a Merge branch 'router-for-me:main' into main 2026-03-08 20:48:05 +08:00
Luis Pater
4f48e5254a Merge pull request #1957 from router-for-me/thinking
fix(translator): pass through adaptive thinking effort
2026-03-08 20:46:58 +08:00
Luis Pater
15dd5db1d7 Merge pull request #1956 from router-for-me/vertex
fix(executor): use aiplatform base url for vertex api key calls
2026-03-08 20:46:28 +08:00
hkfires
424711b718 fix(executor): use aiplatform base url for vertex api key calls 2026-03-08 20:13:12 +08:00
skad
91a2b1f0b4 Fixed: preserve Responses computer tool passthrough
Keep the OpenAI Responses computer tool intact when normalizing requests for the GitHub Copilot executor.

This change preserves built-in computer tool definitions instead of dropping them as non-function tools, keeps explicit computer tool_choice selections unchanged, and classifies computer_call / computer_call_output items as assistant and tool turns when deriving the initiator header.

Together these adjustments allow Responses requests that use the computer tool to reach the upstream executor without losing tool metadata or switching turn ownership unexpectedly.
2026-03-08 13:59:32 +08:00
Luis Pater
2b134fc378 test(auth-scheduler): add unit tests and scheduler implementation
- Added comprehensive unit tests for `authScheduler` and related components.
- Implemented `authScheduler` with support for Round Robin, Fill First, and custom selector strategies.
- Improved tracking of auth states, cooldowns, and recovery logic in scheduler.
2026-03-08 05:52:55 +08:00
Luis Pater
b9153719b0 Merge pull request #1925 from shenshuoyaoyouguang/pr/openai-compat-pool-thinking
fix(openai-compat): improve pool fallback and preserve adaptive thinking
2026-03-08 01:05:05 +08:00
Luis Pater
631e5c8331 Merge pull request #1922 from shenshuoyaoyouguang/pr/model-registry-safety
fix(registry): clone model snapshots and invalidate available-model cache
2026-03-07 23:01:42 +08:00
Luis Pater
e9c60a0a67 Merge pull request #1910 from thebtf/fix/gemini-oauth-error-messages
fix: surface upstream error details in Gemini CLI OAuth onboarding UI
2026-03-07 22:25:18 +08:00
Luis Pater
98a1bb5a7f Merge pull request #1900 from rex-zsd/feature/add-gemini-3.1-flash-image-preview
feat(registry): add gemini-3.1-flash-image-preview model definition
2026-03-07 22:17:10 +08:00
Luis Pater
ca90487a8c Merge branch 'main' into feature/add-gemini-3.1-flash-image-preview 2026-03-07 22:16:09 +08:00
Luis Pater
1042489f85 Merge pull request #1893 from thebtf/fix/normalize-ttl-byte-preservation-mainline
fix: preserve original JSON bytes in normalizeCacheControlTTL
2026-03-07 22:14:13 +08:00
Luis Pater
38277c1ea6 Merge pull request #1875 from woqiqishi/fix/tool-use-id-sanitize
fix: sanitize tool_use.id to comply with Claude API regex ^[a-zA-Z0-9_-]+$
2026-03-07 22:06:36 +08:00
Luis Pater
ee0c24628f Merge branch 'router-for-me:main' into main 2026-03-07 20:42:22 +08:00
chujian
3a18f6fcca fix(registry): clone slice fields in model map output 2026-03-07 18:53:56 +08:00
chujian
099e734a02 fix(registry): always clone available model snapshots 2026-03-07 18:40:02 +08:00
chujian
a52da26b5d fix(auth): stop draining stream pool goroutines after context cancellation 2026-03-07 18:30:33 +08:00
chujian
522a68a4ea fix(openai-compat): retry empty bootstrap streams 2026-03-07 18:08:13 +08:00
chujian
a02eda54d0 fix(openai-compat): address review feedback 2026-03-07 17:39:42 +08:00
chujian
97ef633c57 fix(registry): address review feedback 2026-03-07 17:36:57 +08:00
chujian
dae8463ba1 fix(registry): clone model snapshots and invalidate available-model cache 2026-03-07 16:59:23 +08:00
chujian
7c1299922e fix(openai-compat): improve pool fallback and preserve adaptive thinking 2026-03-07 16:54:28 +08:00
Luis Pater
ddcf1f279d Fixed: #1901
test(websocket): add tests for incremental input and prewarm handling logic

- Added test cases for incremental input support based on upstream capabilities.
- Introduced validation for prewarm handling of `response.create` messages locally.
- Enhanced test coverage for websocket executor behavior, including payload forwarding checks.
- Updated websocket implementation with prewarm and incremental input logic for better testability.
2026-03-07 13:11:28 +08:00
Luis Pater
7e6bb8fdc5 Merge origin/dev into pr-1774-review and resolve watcher conflicts 2026-03-07 11:12:42 +08:00
Luis Pater
9cee8ef87b Merge pull request #1684 from alexey-yanchenko/fix/input-audio-from-openai-to-antigravity
fix: preserve input_audio content parts when proxying to Antigravity
2026-03-07 10:12:28 +08:00
Luis Pater
93fb841bcb Fixed: #1670
test(translator): add unit tests for OpenAI to Claude requests and tool result handling

- Introduced tests for converting OpenAI requests to Claude with text, base64 images, and URL images in tool results.
- Refactored `convertClaudeToolResultContent` and related functionality to properly handle raw content with images and text.
- Updated conversion logic to streamline image handling for both base64 and URL formats.
2026-03-07 09:25:22 +08:00
Luis Pater
0c05131aeb Merge branch 'router-for-me:main' into main 2026-03-07 09:08:28 +08:00
Luis Pater
5ebc58fab4 refactor(executor): remove legacy connCreateSent logic and standardize response.create usage for all websocket events
- Simplified connection logic by removing `connCreateSent` and related state handling.
- Updated `buildCodexWebsocketRequestBody` to always use `response.create`.
- Added unit tests to validate `response.create` behavior and beta header preservation.
- Dropped unsupported `response.append` and outdated `response.done` event types.
2026-03-07 09:07:23 +08:00
Luis Pater
2b609dd891 Merge pull request #1912 from FradSer/main
feat(registry): add gemini 3.1 flash lite preview
2026-03-07 05:41:31 +08:00
Frad LEE
a8cbc68c3e feat(registry): add gemini 3.1 flash lite preview
- Add model to GetGeminiModels()
- Add model to GetGeminiVertexModels()
- Add model to GetGeminiCLIModels()
- Add model to GetAIStudioModels()
- Add to AntigravityModelConfig with thinking levels
- Update gemini-3-flash-preview description

Registers the new lightweight Gemini model across all provider
endpoints for cost-effective high-volume usage scenarios.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 20:52:28 +08:00
Kirill Turanskiy
11a795a01c fix: surface upstream error details in Gemini CLI OAuth onboarding UI
SetOAuthSessionError previously sent generic messages to the management
panel (e.g. "Failed to complete Gemini CLI onboarding"), hiding the
actual error returned by Google APIs. The specific error was only
written to the server log via log.Errorf, which is often inaccessible
in headless/Docker deployments.

Include the upstream error in all 8 OAuth error paths so the
management panel shows actionable messages like "no Google Cloud
projects available for this account" instead of a generic failure.
2026-03-06 13:06:37 +03:00
Luis Pater
89c428216e Merge branch 'router-for-me:main' into main 2026-03-06 11:09:31 +08:00
Luis Pater
2695a99623 fix(translator): conditionally remove service_tier from OpenAI response processing 2026-03-06 11:07:22 +08:00
zhongnan.rex
242aecd924 feat(registry): add gemini-3.1-flash-image-preview model definition 2026-03-06 10:50:04 +08:00
hkfires
ce8cc1ba33 fix(translator): pass through adaptive thinking effort 2026-03-06 09:13:32 +08:00
Luis Pater
ad5253bd2b Merge branch 'router-for-me:main' into main 2026-03-06 04:15:55 +08:00
Kirill Turanskiy
97fdd2e088 fix: preserve original JSON bytes in normalizeCacheControlTTL when no TTL change needed
normalizeCacheControlTTL unconditionally re-serializes the entire request
body through json.Unmarshal/json.Marshal even when no TTL normalization
is needed. Go's json.Marshal randomizes map key order and HTML-escapes
<, >, & characters (to \u003c, \u003e, \u0026), producing different raw
bytes on every call.

Anthropic's prompt caching uses byte-prefix matching, so any byte-level
difference causes a cache miss. This means the ~119K system prompt and
tools are re-processed on every request when routed through CPA.

The fix adds a bool return to normalizeTTLForBlock to indicate whether
it actually modified anything, and skips the marshal step in
normalizeCacheControlTTL when no blocks were changed.
2026-03-05 22:28:01 +03:00
Luis Pater
9397f7049f fix(registry): simplify GPT 5.4 model description in static data 2026-03-06 02:32:56 +08:00
Xu Hong
553d6f50ea fix: sanitize tool_use.id to comply with Claude API regex ^[a-zA-Z0-9_-]+$
Add util.SanitizeClaudeToolID() to replace non-conforming characters in
tool_use.id fields across all five response translators (gemini, codex,
openai, antigravity, gemini-cli).

Upstream tool names may contain dots or other special characters
(e.g. "fs.readFile") that violate Claude's ID validation regex.
The sanitizer replaces such characters with underscores and provides
a generated fallback for empty IDs.

Fixes #1872, Fixes #1849

Made-with: Cursor
2026-03-06 00:10:09 +08:00
lyd123qw2008
dd44413ba5 refactor(watcher): make authSliceToMap always return map 2026-03-02 10:09:56 +08:00
lyd123qw2008
10fa0f2062 refactor(watcher): dedupe auth map conversion in incremental flow 2026-03-02 10:03:42 +08:00
lyd123qw2008
30338ecec4 perf(watcher): remove redundant auth clones in incremental path 2026-03-01 14:05:11 +08:00
lyd123qw2008
9a37defed3 test(watcher): restore main test names and max-retry callback coverage 2026-03-01 13:54:03 +08:00
lyd123qw2008
c83a057996 refactor(watcher): make auth file events fully incremental 2026-03-01 13:42:42 +08:00
Alexey Yanchenko
b7588428c5 fix: preserve input_audio content parts when proxying to Antigravity
- Add input_audio handling in chat/completions translator (antigravity_openai_request.go)
- Add input_audio handling in responses translator (gemini_openai-responses_request.go)
- Map OpenAI audio formats (mp3, wav, ogg, flac, aac, webm, pcm16, g711_ulaw, g711_alaw) to correct MIME types for Gemini inlineData
2026-02-23 20:50:28 +07:00
101 changed files with 14481 additions and 2870 deletions

View File

@@ -16,6 +16,10 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
@@ -47,6 +51,10 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub

View File

@@ -12,6 +12,10 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Go
uses: actions/setup-go@v5
with:

View File

@@ -16,6 +16,10 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- run: git fetch --force --tags
- uses: actions/setup-go@v4
with:

117
README.md
View File

@@ -8,123 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
The Plus release stays in lockstep with the mainline features.
## Differences from the Mainline
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
## New Features (Plus Enhanced)
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & GLM-5 Only Available for Pro Usersmodel across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
## Kiro Authentication
### CLI Login
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
**AWS Builder ID** (recommended):
```bash
# Device code flow
./CLIProxyAPI --kiro-aws-login
# Authorization code flow
./CLIProxyAPI --kiro-aws-authcode
```
**Import token from Kiro IDE:**
```bash
./CLIProxyAPI --kiro-import
```
To get a token from Kiro IDE:
1. Open Kiro IDE and login with Google (or GitHub)
2. Find the token file: `~/.kiro/kiro-auth-token.json`
3. Run: `./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC):**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# Specify region
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**Additional flags:**
| Flag | Description |
|------|-------------|
| `--no-browser` | Don't open browser automatically, print URL instead |
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
### Web-based OAuth Login
Access the Kiro OAuth web interface at:
```
http://your-server:8080/v0/oauth/kiro
```
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
- AWS Builder ID login
- AWS Identity Center (IDC) login
- Token import from Kiro IDE
## Quick Deployment with Docker
### One-Command Deployment
```bash
# Create deployment directory
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# Create docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# Download example config
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# Pull and start
docker compose pull && docker compose up -d
```
### Configuration
Edit `config.yaml` before starting:
```yaml
# Basic configuration example
server:
port: 8317
# Add your provider configurations here
```
### Update to Latest Version
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## Contributing
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.

View File

@@ -6,125 +6,6 @@
所有的第三方供应商支持都由第三方社区维护者提供CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
该 Plus 版本的主线功能与主线功能强制同步。
## 与主线版本版本差异
[![bigmodel.cn](https://assets.router-for.me/chinese-5-0.jpg)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
## 新增功能 (Plus 增强版)
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7受限于算力目前仅限Pro用户开放为开发者提供顶尖的编码体验。
智谱AI为本产品提供了特别优惠使用以下链接购买可以享受九折优惠https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
### 命令行登录
> **注意:** 由于 AWS Cognito 限制Google/GitHub 登录不可用于第三方应用。
**AWS Builder ID**(推荐):
```bash
# 设备码流程
./CLIProxyAPI --kiro-aws-login
# 授权码流程
./CLIProxyAPI --kiro-aws-authcode
```
**从 Kiro IDE 导入令牌:**
```bash
./CLIProxyAPI --kiro-import
```
获取令牌步骤:
1. 打开 Kiro IDE使用 Google或 GitHub登录
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
3. 运行:`./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC)**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# 指定区域
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**附加参数:**
| 参数 | 说明 |
|------|------|
| `--no-browser` | 不自动打开浏览器,打印 URL |
| `--no-incognito` | 使用已有浏览器会话Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
| `--kiro-idc-start-url` | IDC Start URL`--kiro-idc-login` 必需) |
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1` |
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
### 网页端 OAuth 登录
访问 Kiro OAuth 网页认证界面:
```
http://your-server:8080/v0/oauth/kiro
```
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
- AWS Builder ID 登录
- AWS Identity Center (IDC) 登录
- 从 Kiro IDE 导入令牌
## Docker 快速部署
### 一键部署
```bash
# 创建部署目录
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# 创建 docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# 下载示例配置
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# 拉取并启动
docker compose pull && docker compose up -d
```
### 配置说明
启动前请编辑 `config.yaml`
```yaml
# 基本配置示例
server:
port: 8317
# 在此添加你的供应商配置
```
### 更新到最新版本
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## 贡献
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
@@ -133,4 +14,4 @@ docker compose pull && docker compose up -d
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -0,0 +1,275 @@
// Command fetch_antigravity_models connects to the Antigravity API using the
// stored auth credentials and saves the dynamically fetched model list to a
// JSON file for inspection or offline use.
//
// Usage:
//
// go run ./cmd/fetch_antigravity_models [flags]
//
// Flags:
//
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
// --output <path> Output JSON file path (default: "antigravity_models.json")
// --pretty Pretty-print the output JSON (default: true)
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
)
func init() {
logging.SetupBaseLogger()
log.SetLevel(log.InfoLevel)
}
// modelOutput wraps the fetched model list with fetch metadata.
type modelOutput struct {
Models []modelEntry `json:"models"`
}
// modelEntry contains only the fields we want to keep for static model definitions.
type modelEntry struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
Name string `json:"name"`
Description string `json:"description"`
ContextLength int `json:"context_length,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
func main() {
var authsDir string
var outputPath string
var pretty bool
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
flag.Parse()
// Resolve relative paths against the working directory.
wd, err := os.Getwd()
if err != nil {
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
os.Exit(1)
}
if !filepath.IsAbs(authsDir) {
authsDir = filepath.Join(wd, authsDir)
}
if !filepath.IsAbs(outputPath) {
outputPath = filepath.Join(wd, outputPath)
}
fmt.Printf("Scanning auth files in: %s\n", authsDir)
// Load all auth records from the directory.
fileStore := sdkauth.NewFileTokenStore()
fileStore.SetBaseDir(authsDir)
ctx := context.Background()
auths, err := fileStore.List(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
os.Exit(1)
}
if len(auths) == 0 {
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
os.Exit(1)
}
// Find the first enabled antigravity auth.
var chosen *coreauth.Auth
for _, a := range auths {
if a == nil || a.Disabled {
continue
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
chosen = a
break
}
}
if chosen == nil {
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
os.Exit(1)
}
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
// Fetch models from the upstream Antigravity API.
fmt.Println("Fetching Antigravity model list from upstream...")
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
models := fetchModels(fetchCtx, chosen)
if len(models) == 0 {
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
} else {
fmt.Printf("Fetched %d models.\n", len(models))
}
// Build the output payload.
out := modelOutput{
Models: models,
}
// Marshal to JSON.
var raw []byte
if pretty {
raw, err = json.MarshalIndent(out, "", " ")
} else {
raw, err = json.Marshal(out)
}
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
os.Exit(1)
}
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
os.Exit(1)
}
fmt.Printf("Model list saved to: %s\n", outputPath)
}
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
return nil
}
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
for _, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
if errReq != nil {
continue
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
httpClient := &http.Client{Timeout: 30 * time.Second}
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
httpClient.Transport = transport
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
continue
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
httpResp.Body.Close()
if errRead != nil {
continue
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
continue
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
continue
}
var models []modelEntry
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
// Skip internal/experimental models
switch modelID {
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
entry := modelEntry{
ID: modelID,
Object: "model",
OwnedBy: "antigravity",
Type: "antigravity",
DisplayName: displayName,
Name: modelID,
Description: displayName,
}
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
entry.ContextLength = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
entry.MaxCompletionTokens = int(maxOut)
}
models = append(models, entry)
}
return models
}
return nil
}
func metaStringValue(m map[string]interface{}, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
switch val := v.(type) {
case string:
return val
default:
return ""
}
}

View File

@@ -25,6 +25,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
@@ -78,6 +79,8 @@ func main() {
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var gitlabLogin bool
var gitlabTokenLogin bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
@@ -110,6 +113,8 @@ func main() {
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
@@ -526,6 +531,10 @@ func main() {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if gitlabLogin {
cmd.DoGitLabLogin(cfg, options)
} else if gitlabTokenLogin {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if kiroLogin {
@@ -573,6 +582,7 @@ func main() {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
@@ -643,15 +653,16 @@ func main() {
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
cmd.StartService(cfg, configFilePath, password)
}
}
}

View File

@@ -68,7 +68,8 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ''
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
@@ -115,6 +116,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
@@ -133,6 +135,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
@@ -151,6 +154,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
@@ -178,6 +182,14 @@ nonstream-keepalive-interval: 0
# runtime-version: "v24.3.0"
# timeout: "600"
# Default headers for Codex OAuth model requests.
# These are used only for file-backed/OAuth Codex requests when the client
# does not send the header. `user-agent` applies to HTTP and websocket requests;
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
# codex-header-defaults:
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
# beta-features: "multi_agent"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
#kiro:
@@ -215,10 +227,22 @@ nonstream-keepalive-interval: 0
# api-key-entries:
# - api-key: "sk-or-v1-...b780"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,
# # and if the chosen upstream fails before producing output, the request will
# # continue with the next upstream model in the same alias pool.
# - name: "qwen3.5-plus"
# alias: "claude-opus-4.66"
# - name: "glm-5"
# alias: "claude-opus-4.66"
# - name: "kimi-k2.5"
# alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
# vertex-api-key:
@@ -226,6 +250,7 @@ nonstream-keepalive-interval: 0
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names

115
docs/gitlab-duo.md Normal file
View File

@@ -0,0 +1,115 @@
# GitLab Duo guide
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
It supports:
- OAuth login
- personal access token login
- automatic refresh of GitLab `direct_access` metadata
- dynamic model discovery from GitLab metadata
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
- Claude-compatible and OpenAI-compatible downstream APIs
## What this means
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
- a stable alias: `gitlab-duo`
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
## Login
OAuth login:
```bash
./CLIProxyAPI -gitlab-login
```
PAT login:
```bash
./CLIProxyAPI -gitlab-token-login
```
You can also provide inputs through environment variables:
```bash
export GITLAB_BASE_URL=https://gitlab.com
export GITLAB_OAUTH_CLIENT_ID=your-client-id
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
```
Notes:
- OAuth requires a GitLab OAuth application.
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
## Using the models
After login, start CLIProxyAPI normally and point your client at the local proxy.
You can select:
- `gitlab-duo` to use the current Duo-managed model for that account
- the discovered provider model name if you want to pin it explicitly
Examples:
```bash
curl http://127.0.0.1:8080/v1/models
```
```bash
curl http://127.0.0.1:8080/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "gitlab-duo",
"messages": [
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
]
}'
```
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
## How model freshness works
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
This matches GitLab's current public contract better than hardcoding model names.
## Current scope
The GitLab Duo provider now has:
- OAuth and PAT auth flows
- runtime refresh of Duo gateway credentials
- native Anthropic gateway routing
- native OpenAI/Codex gateway routing
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
Still out of scope today:
- websocket or session-specific parity beyond the current HTTP APIs
- GitLab-specific IDE features that are not exposed through the public gateway contract
## References
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/

115
docs/gitlab-duo_CN.md Normal file
View File

@@ -0,0 +1,115 @@
# GitLab Duo 使用说明
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
当前支持:
- OAuth 登录
- personal access token 登录
- 自动刷新 GitLab `direct_access` 元数据
- 根据 GitLab 返回的元数据动态发现模型
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
- Claude 兼容与 OpenAI 兼容下游 API
## 这意味着什么
如果 GitLab Duo 返回的是 Anthropic 托管模型CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
如果 GitLab Duo 返回的是 OpenAI 托管模型CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
- 一个稳定别名:`gitlab-duo`
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5``gpt-5-codex`
## 登录
OAuth 登录:
```bash
./CLIProxyAPI -gitlab-login
```
PAT 登录:
```bash
./CLIProxyAPI -gitlab-token-login
```
也可以通过环境变量提供输入:
```bash
export GITLAB_BASE_URL=https://gitlab.com
export GITLAB_OAUTH_CLIENT_ID=your-client-id
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
```
说明:
- OAuth 方式需要一个 GitLab OAuth application。
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上`api` scope 是最稳妥的基线。
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
## 如何使用模型
登录完成后,正常启动 CLIProxyAPI并让客户端连接到本地代理。
你可以选择:
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
示例:
```bash
curl http://127.0.0.1:8080/v1/models
```
```bash
curl http://127.0.0.1:8080/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "gitlab-duo",
"messages": [
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
]
}'
```
如果该 GitLab 账号当前绑定的是 Anthropic 模型Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型OpenAI 兼容客户端可以通过 `/v1/chat/completions``/v1/responses` 使用它。
## 模型如何保持最新
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
## 当前覆盖范围
GitLab Duo Provider 目前已经具备:
- OAuth 和 PAT 登录流程
- Duo gateway 凭据的运行时刷新
- Anthropic gateway 原生路由
- OpenAI/Codex gateway 原生路由
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
当前仍未覆盖:
- websocket 或 session 级别的完全对齐
- GitLab 公开 gateway 契约之外的 IDE 专有能力
## 参考资料
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/

View File

@@ -0,0 +1,278 @@
# Plan: GitLab Duo Codex Parity
**Generated**: 2026-03-10
**Estimated Complexity**: High
## Overview
Bring GitLab Duo support from the current "auth + basic executor" stage to the same practical level as `codex` inside `CLIProxyAPI`: a user logs in once, points external clients such as Claude Code at `CLIProxyAPI`, selects GitLab Duo-backed models, and gets stable streaming, multi-turn behavior, tool calling compatibility, and predictable model routing without manual provider-specific workarounds.
The core architectural shift is to stop treating GitLab Duo as only two REST wrappers (`/api/v4/chat/completions` and `/api/v4/code_suggestions/completions`) and instead use GitLab's `direct_access` contract as the primary runtime entrypoint wherever possible. Official GitLab docs confirm that `direct_access` returns AI gateway connection details, headers, token, and expiry; that contract is the closest path to codex-like provider behavior.
## Prerequisites
- Official GitLab Duo API references confirmed during implementation:
- `POST /api/v4/code_suggestions/direct_access`
- `POST /api/v4/code_suggestions/completions`
- `POST /api/v4/chat/completions`
- Access to at least one real GitLab Duo account for manual verification.
- One downstream client target for acceptance testing:
- Claude Code against Claude-compatible endpoint
- OpenAI-compatible client against `/v1/chat/completions` and `/v1/responses`
- Existing PR branch as starting point:
- `feat/gitlab-duo-auth`
- PR [#2028](https://github.com/router-for-me/CLIProxyAPI/pull/2028)
## Definition Of Done
- GitLab Duo models can be used via `CLIProxyAPI` from the same client surfaces that already work for `codex`.
- Upstream streaming is real passthrough or faithful chunked forwarding, not synthetic whole-response replay.
- Tool/function calling survives translation layers without dropping fields or corrupting names.
- Multi-turn and session semantics are stable across `chat/completions`, `responses`, and Claude-compatible routes.
- Model exposure stays current from GitLab metadata or gateway discovery without hardcoded stale model tables.
- `go test ./...` stays green and at least one real manual end-to-end client flow is documented.
## Sprint 1: Contract And Gap Closure
**Goal**: Replace assumptions with a hard compatibility contract between current `codex` behavior and what GitLab Duo can actually support.
**Demo/Validation**:
- Written matrix showing `codex` features vs current GitLab Duo behavior.
- One checked-in developer note or test fixture for real GitLab Duo payload examples.
### Task 1.1: Freeze Codex Parity Checklist
- **Location**: [internal/runtime/executor/codex_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_executor.go), [internal/runtime/executor/codex_websockets_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_websockets_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go)
- **Description**: Produce a concrete feature matrix for `codex`: HTTP execute, SSE execute, `/v1/responses`, websocket downstream path, tool calling, request IDs, session close semantics, and model registration behavior.
- **Dependencies**: None
- **Acceptance Criteria**:
- A checklist exists in repo docs or issue notes.
- Each capability is marked `required`, `optional`, or `not possible` for GitLab Duo.
- **Validation**:
- Review against current `codex` code paths.
### Task 1.2: Lock GitLab Duo Runtime Contract
- **Location**: [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Validate the exact upstream contract we can rely on:
- `direct_access` fields and refresh cadence
- whether AI gateway path is usable directly
- when `chat/completions` is available vs when fallback is required
- what streaming shape is returned by `code_suggestions/completions?stream=true`
- **Dependencies**: Task 1.1
- **Acceptance Criteria**:
- GitLab transport decision is explicit: `gateway-first`, `REST-first`, or `hybrid`.
- Unknown areas are isolated behind feature flags, not spread across executor logic.
- **Validation**:
- Official docs + captured real responses from a Duo account.
### Task 1.3: Define Client-Facing Compatibility Targets
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md), [gitlab-duo-codex-parity-plan.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/gitlab-duo-codex-parity-plan.md)
- **Description**: Define exactly which external flows must work to call GitLab Duo support "like codex".
- **Dependencies**: Task 1.2
- **Acceptance Criteria**:
- Required surfaces are listed:
- Claude-compatible route
- OpenAI `chat/completions`
- OpenAI `responses`
- optional downstream websocket path
- Non-goals are explicit if GitLab upstream cannot support them.
- **Validation**:
- Maintainer review of stated scope.
## Sprint 2: Primary Transport Parity
**Goal**: Move GitLab Duo execution onto a transport that supports codex-like runtime behavior.
**Demo/Validation**:
- A GitLab Duo model works over real streaming through `/v1/chat/completions`.
- No synthetic "collect full body then fake stream" path remains on the primary flow.
### Task 2.1: Refactor GitLab Executor Into Strategy Layers
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Split current executor into explicit strategies:
- auth refresh/direct access refresh
- gateway transport
- GitLab REST fallback transport
- downstream translation helpers
- **Dependencies**: Sprint 1
- **Acceptance Criteria**:
- Executor no longer mixes discovery, refresh, fallback selection, and response synthesis in one path.
- Transport choice is testable in isolation.
- **Validation**:
- Unit tests for strategy selection and fallback boundaries.
### Task 2.2: Implement Real Streaming Path
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go)
- **Description**: Replace synthetic streaming with true upstream incremental forwarding:
- use gateway stream if available
- otherwise consume GitLab Code Suggestions streaming response and map chunks incrementally
- **Dependencies**: Task 2.1
- **Acceptance Criteria**:
- `ExecuteStream` emits chunks before upstream completion.
- error handling preserves status and early failure semantics.
- **Validation**:
- tests with chunked upstream server
- manual curl check against `/v1/chat/completions` with `stream=true`
### Task 2.3: Preserve Upstream Auth And Headers Correctly
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go)
- **Description**: Use `direct_access` connection details as first-class transport state:
- gateway token
- expiry
- mandatory forwarded headers
- model metadata
- **Dependencies**: Task 2.1
- **Acceptance Criteria**:
- executor stops ignoring gateway headers/token when transport requires them
- refresh logic never over-fetches `direct_access`
- **Validation**:
- tests verifying propagated headers and refresh interval behavior
## Sprint 3: Request/Response Semantics Parity
**Goal**: Make GitLab Duo behave correctly under the same request shapes that current `codex` consumers send.
**Demo/Validation**:
- OpenAI and Claude-compatible clients can do non-streaming and streaming conversations without losing structure.
### Task 3.1: Normalize Multi-Turn Message Mapping
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/translator](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/translator)
- **Description**: Replace the current "flatten prompt into one instruction" behavior with stable multi-turn mapping:
- preserve system context
- preserve user/assistant ordering
- maintain bounded context truncation
- **Dependencies**: Sprint 2
- **Acceptance Criteria**:
- multi-turn requests are not collapsed into a lossy single string unless fallback mode explicitly requires it
- truncation policy is deterministic and tested
- **Validation**:
- golden tests for request mapping
### Task 3.2: Tool Calling Compatibility Layer
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go)
- **Description**: Decide and implement one of two paths:
- native pass-through if GitLab gateway supports tool/function structures
- strict downgrade path with explicit unsupported errors instead of silent field loss
- **Dependencies**: Task 3.1
- **Acceptance Criteria**:
- tool-related fields are either preserved correctly or rejected explicitly
- no silent corruption of tool names, tool calls, or tool results
- **Validation**:
- table-driven tests for tool payloads
- one manual client scenario using tools
### Task 3.3: Token Counting And Usage Reporting Fidelity
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/usage_helpers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/usage_helpers.go)
- **Description**: Improve token/usage reporting so GitLab models behave like first-class providers in logs and scheduling.
- **Dependencies**: Sprint 2
- **Acceptance Criteria**:
- `CountTokens` uses the closest supported estimation path
- usage logging distinguishes prompt vs completion when possible
- **Validation**:
- unit tests for token estimation outputs
## Sprint 4: Responses And Session Parity
**Goal**: Reach codex-level support for OpenAI Responses clients and long-lived sessions where GitLab upstream permits it.
**Demo/Validation**:
- `/v1/responses` works with GitLab Duo in a realistic client flow.
- If websocket parity is not possible, the code explicitly declines it and keeps HTTP paths stable.
### Task 4.1: Make GitLab Compatible With `/v1/responses`
- **Location**: [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Ensure GitLab transport can safely back the Responses API path, including compact responses if applicable.
- **Dependencies**: Sprint 3
- **Acceptance Criteria**:
- GitLab Duo can be selected behind `/v1/responses`
- response IDs and follow-up semantics are defined
- **Validation**:
- handler tests analogous to codex/openai responses tests
### Task 4.2: Evaluate Downstream Websocket Parity
- **Location**: [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Decide whether GitLab Duo can support downstream websocket sessions like codex:
- if yes, add session-aware execution path
- if no, mark GitLab auth as websocket-ineligible and keep HTTP routes first-class
- **Dependencies**: Task 4.1
- **Acceptance Criteria**:
- websocket behavior is explicit, not accidental
- no route claims websocket support when the upstream cannot honor it
- **Validation**:
- websocket handler tests or explicit capability tests
### Task 4.3: Add Session Cleanup And Failure Recovery Semantics
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/cliproxy/auth/conductor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/auth/conductor.go)
- **Description**: Add codex-like session cleanup, retry boundaries, and model suspension/resume behavior for GitLab failures and quota events.
- **Dependencies**: Sprint 2
- **Acceptance Criteria**:
- auth/model cooldown behavior is predictable on GitLab 4xx/5xx/quota responses
- executor cleans up per-session resources if any are introduced
- **Validation**:
- tests for quota and retry behavior
## Sprint 5: Client UX, Model UX, And Manual E2E
**Goal**: Make GitLab Duo feel like a normal built-in provider to operators and downstream clients.
**Demo/Validation**:
- A documented setup exists for "login once, point Claude Code at CLIProxyAPI, use GitLab Duo-backed model".
### Task 5.1: Model Alias And Provider UX Cleanup
- **Location**: [sdk/cliproxy/service.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/service.go), [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
- **Description**: Normalize what users see:
- stable alias such as `gitlab-duo`
- discovered upstream model names
- optional prefix behavior
- account labels that clearly distinguish OAuth vs PAT
- **Dependencies**: Sprint 3
- **Acceptance Criteria**:
- users can select a stable GitLab alias even when upstream model changes
- dynamic model discovery does not cause confusing model churn
- **Validation**:
- registry tests and manual `/v1/models` inspection
### Task 5.2: Add Real End-To-End Acceptance Tests
- **Location**: [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go), [sdk/api/handlers/openai](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai)
- **Description**: Add higher-level tests covering the actual proxy surfaces:
- OpenAI `chat/completions`
- OpenAI `responses`
- Claude-compatible request path if GitLab is routed there
- **Dependencies**: Sprint 4
- **Acceptance Criteria**:
- tests fail if streaming regresses into synthetic buffering again
- tests cover at least one tool-related request and one multi-turn request
- **Validation**:
- `go test ./...`
### Task 5.3: Publish Operator Documentation
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
- **Description**: Document:
- OAuth setup requirements
- PAT requirements
- current capability matrix
- known limitations if websocket/tool parity is partial
- **Dependencies**: Sprint 5.1
- **Acceptance Criteria**:
- setup instructions are enough for a new user to reproduce the GitLab Duo flow
- limitations are explicit
- **Validation**:
- dry-run docs review from a clean environment
## Testing Strategy
- Keep `go test ./...` green after every committable task.
- Add table-driven tests first for request mapping, refresh behavior, and dynamic model registration.
- Add transport tests with `httptest.Server` for:
- real chunked streaming
- header propagation from `direct_access`
- upstream fallback rules
- Add at least one manual acceptance checklist:
- login via OAuth
- login via PAT
- list models
- run one streaming prompt via OpenAI route
- run one prompt from the target downstream client
## Potential Risks & Gotchas
- GitLab public docs expose `direct_access`, but do not fully document every possible AI gateway path. We should isolate any empirically discovered gateway assumptions behind one transport layer and feature flags.
- `chat/completions` availability differs by GitLab offering and version. The executor must not assume it always exists.
- Code Suggestions is completion-oriented; lossy mapping from rich chat/tool payloads will make GitLab Duo feel worse than codex unless explicitly handled.
- Synthetic streaming is not good enough for codex parity and will cause regressions in interactive clients.
- Dynamic model discovery can create unstable UX if the stable alias and discovered model IDs are not separated cleanly.
- PAT auth may validate successfully while still lacking effective Duo permissions. Error reporting must surface this explicitly.
## Rollback Plan
- Keep the current basic GitLab executor behind a fallback mode until the new transport path is stable.
- If parity work destabilizes existing providers, revert only GitLab-specific executor changes and leave auth support intact.
- Preserve the stable `gitlab-duo` alias so rollback does not break client configuration.

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
@@ -14,13 +13,12 @@ import (
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
@@ -725,47 +723,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {
log.WithError(errBuild).Debug("build proxy transport failed")
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
return transport
}
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).

View File

@@ -1,173 +1,58 @@
package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
_ = ctx
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
if httpTransport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
t.Parallel()
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
}
}

View File

@@ -29,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
@@ -54,6 +55,8 @@ const (
codexCallbackPort = 1455
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
gitLabLoginModeOAuth = "oauth"
gitLabLoginModePAT = "pat"
)
type callbackForwarder struct {
@@ -999,6 +1002,165 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
return store.Save(ctx, record)
}
func gitLabBaseURLFromRequest(c *gin.Context) string {
if c != nil {
if raw := strings.TrimSpace(c.Query("base_url")); raw != "" {
return gitlabauth.NormalizeBaseURL(raw)
}
}
if raw := strings.TrimSpace(os.Getenv("GITLAB_BASE_URL")); raw != "" {
return gitlabauth.NormalizeBaseURL(raw)
}
return gitlabauth.DefaultBaseURL
}
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
metadata := map[string]any{
"type": "gitlab",
"auth_method": strings.TrimSpace(mode),
"base_url": gitlabauth.NormalizeBaseURL(baseURL),
"last_refresh": time.Now().UTC().Format(time.RFC3339),
"refresh_interval_seconds": 240,
}
if tokenResp != nil {
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
metadata["refresh_token"] = refreshToken
}
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
metadata["token_type"] = tokenType
}
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
metadata["scope"] = scope
}
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
}
}
mergeGitLabDirectAccessMetadata(metadata, direct)
return metadata
}
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
if metadata == nil || direct == nil {
return
}
if base := strings.TrimSpace(direct.BaseURL); base != "" {
metadata["duo_gateway_base_url"] = base
}
if token := strings.TrimSpace(direct.Token); token != "" {
metadata["duo_gateway_token"] = token
}
if direct.ExpiresAt > 0 {
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
now := time.Now().UTC()
if ttl := expiry.Sub(now); ttl > 0 {
interval := int(ttl.Seconds()) / 2
switch {
case interval < 60:
interval = 60
case interval > 240:
interval = 240
}
metadata["refresh_interval_seconds"] = interval
}
}
if len(direct.Headers) > 0 {
headers := make(map[string]string, len(direct.Headers))
for key, value := range direct.Headers {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
headers[key] = value
}
if len(headers) > 0 {
metadata["duo_gateway_headers"] = headers
}
}
if direct.ModelDetails != nil {
modelDetails := map[string]any{}
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
modelDetails["model_provider"] = provider
metadata["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
modelDetails["model_name"] = model
metadata["model_name"] = model
}
if len(modelDetails) > 0 {
metadata["model_details"] = modelDetails
}
}
}
func primaryGitLabEmail(user *gitlabauth.User) string {
if user == nil {
return ""
}
if value := strings.TrimSpace(user.Email); value != "" {
return value
}
return strings.TrimSpace(user.PublicEmail)
}
func gitLabAccountIdentifier(user *gitlabauth.User) string {
if user == nil {
return "user"
}
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return "user"
}
func sanitizeGitLabFileName(value string) string {
value = strings.TrimSpace(strings.ToLower(value))
if value == "" {
return "user"
}
var builder strings.Builder
lastDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
builder.WriteRune(r)
lastDash = false
case r >= '0' && r <= '9':
builder.WriteRune(r)
lastDash = false
case r == '-' || r == '_' || r == '.':
builder.WriteRune(r)
lastDash = false
default:
if !lastDash {
builder.WriteRune('-')
lastDash = true
}
}
}
result := strings.Trim(builder.String(), "-")
if result == "" {
return "user"
}
return result
}
func maskGitLabToken(token string) string {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return ""
}
if len(trimmed) <= 8 {
return trimmed
}
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -1312,12 +1474,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
if errAll != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
return
}
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
return
}
ts.ProjectID = strings.Join(projects, ",")
@@ -1326,7 +1488,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ts.Auto = false
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
SetOAuthSessionError(state, "Google One auto-discovery failed")
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
return
}
if strings.TrimSpace(ts.ProjectID) == "" {
@@ -1337,19 +1499,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the auto-discovered project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return
}
} else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
return
}
@@ -1362,13 +1524,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the selected project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return
}
}
@@ -1549,6 +1711,263 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitLabToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing GitLab Duo authentication...")
baseURL := gitLabBaseURLFromRequest(c)
clientID := strings.TrimSpace(c.Query("client_id"))
clientSecret := strings.TrimSpace(c.Query("client_secret"))
if clientID == "" {
clientID = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_ID"))
}
if clientSecret == "" {
clientSecret = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_SECRET"))
}
if clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "gitlab client_id is required"})
return
}
pkceCodes, err := gitlabauth.GeneratePKCECodes()
if err != nil {
log.Errorf("Failed to generate GitLab PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
state, err := misc.GenerateRandomState()
if err != nil {
log.Errorf("Failed to generate GitLab state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
redirectURI := gitlabauth.RedirectURL(gitlabauth.DefaultCallbackPort)
authClient := gitlabauth.NewAuthClient(h.cfg)
authURL, err := authClient.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
if err != nil {
log.Errorf("Failed to generate GitLab authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "gitlab")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/gitlab/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute gitlab callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(gitlabauth.DefaultCallbackPort, "gitlab", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gitlab callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(gitlabauth.DefaultCallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gitlab-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var code string
for {
if !IsOAuthSessionPending(state, "gitlab") {
return
}
if time.Now().After(deadline) {
log.Error("gitlab oauth flow timed out")
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return
}
if data, errRead := os.ReadFile(waitFile); errRead == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
SetOAuthSessionError(state, errStr)
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != state {
SetOAuthSessionError(state, "State code error")
return
}
code = strings.TrimSpace(payload["code"])
if code == "" {
SetOAuthSessionError(state, "Authorization code missing")
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
tokenResp, errExchange := authClient.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, code, pkceCodes.CodeVerifier)
if errExchange != nil {
log.Errorf("Failed to exchange GitLab authorization code: %v", errExchange)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return
}
user, errUser := authClient.GetCurrentUser(ctx, baseURL, tokenResp.AccessToken)
if errUser != nil {
log.Errorf("Failed to fetch GitLab user profile: %v", errUser)
SetOAuthSessionError(state, "Failed to fetch account profile")
return
}
direct, errDirect := authClient.FetchDirectAccess(ctx, baseURL, tokenResp.AccessToken)
if errDirect != nil {
log.Errorf("Failed to fetch GitLab direct access metadata: %v", errDirect)
SetOAuthSessionError(state, "Failed to fetch GitLab Duo access")
return
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata["oauth_client_id"] = clientID
if clientSecret != "" {
metadata["oauth_client_secret"] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
record := &coreauth.Auth{
ID: fileName,
Provider: "gitlab",
FileName: fileName,
Label: identifier,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save GitLab auth record: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("GitLab Duo authentication successful. Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("gitlab")
}()
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitLabPATToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
var payload struct {
BaseURL string `json:"base_url"`
PersonalAccessToken string `json:"personal_access_token"`
Token string `json:"token"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
return
}
baseURL := gitlabauth.NormalizeBaseURL(strings.TrimSpace(payload.BaseURL))
if baseURL == "" {
baseURL = gitLabBaseURLFromRequest(nil)
}
pat := strings.TrimSpace(payload.PersonalAccessToken)
if pat == "" {
pat = strings.TrimSpace(payload.Token)
}
if pat == "" {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "personal_access_token is required"})
return
}
authClient := gitlabauth.NewAuthClient(h.cfg)
user, err := authClient.GetCurrentUser(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
patSelf, err := authClient.GetPersonalAccessTokenSelf(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
direct, err := authClient.FetchDirectAccess(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
metadata["auth_kind"] = "personal_access_token"
metadata["personal_access_token"] = pat
metadata["token_preview"] = maskGitLabToken(pat)
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
if patSelf != nil {
if name := strings.TrimSpace(patSelf.Name); name != "" {
metadata["pat_name"] = name
}
if len(patSelf.Scopes) > 0 {
metadata["pat_scopes"] = append([]string(nil), patSelf.Scopes...)
}
}
record := &coreauth.Auth{
ID: fileName,
Provider: "gitlab",
FileName: fileName,
Label: identifier + " (PAT)",
Metadata: metadata,
}
savedPath, err := h.saveTokenRecord(ctx, record)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
return
}
response := gin.H{
"status": "ok",
"saved_path": savedPath,
"username": strings.TrimSpace(user.Username),
"email": primaryGitLabEmail(user),
"token_label": identifier,
}
if direct != nil && direct.ModelDetails != nil {
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
response["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
response["model_name"] = model
}
}
fmt.Printf("GitLab Duo PAT authentication successful. Token saved to %s\n", savedPath)
c.JSON(http.StatusOK, response)
}
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)

View File

@@ -0,0 +1,164 @@
package management
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
}
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/api/v4/user":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 42,
"username": "gitlab-user",
"name": "GitLab User",
"email": "gitlab@example.com",
})
case "/api/v4/personal_access_tokens/self":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 7,
"name": "management-center",
"scopes": []string{"api", "read_user"},
"user_id": 42,
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1893456000,
"headers": map[string]string{
"X-Gitlab-Realm": "saas",
},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
store := &memoryAuthStore{}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
h.tokenStore = store
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.RequestGitLabPATToken(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var resp map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if got := resp["status"]; got != "ok" {
t.Fatalf("status = %#v, want ok", got)
}
if got := resp["model_provider"]; got != "anthropic" {
t.Fatalf("model_provider = %#v, want anthropic", got)
}
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
}
store.mu.Lock()
defer store.mu.Unlock()
if len(store.items) != 1 {
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
}
var saved *coreauth.Auth
for _, item := range store.items {
saved = item
}
if saved == nil {
t.Fatal("expected saved auth record")
}
if saved.Provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", saved.Provider)
}
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
}
if got := saved.Metadata["model_provider"]; got != "anthropic" {
t.Fatalf("saved model_provider = %#v, want anthropic", got)
}
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
}
}
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
state := "gitlab-state-123"
RegisterOAuthSession(state, "gitlab")
t.Cleanup(func() { CompleteOAuthSession(state) })
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.PostOAuthCallback(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
data, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("read callback file: %v", err)
}
var payload map[string]string
if err := json.Unmarshal(data, &payload); err != nil {
t.Fatalf("decode callback payload: %v", err)
}
if got := payload["code"]; got != "test-code" {
t.Fatalf("callback code = %q, want test-code", got)
}
if got := payload["state"]; got != state {
t.Fatalf("callback state = %q, want %q", got, state)
}
}
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
provider, err := NormalizeOAuthProvider("gitlab")
if err != nil {
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
}
if provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", provider)
}
}

View File

@@ -228,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "anthropic", nil
case "codex", "openai":
return "codex", nil
case "gitlab":
return "gitlab", nil
case "gemini", "google":
return "gemini", nil
case "iflow", "i-flow":

View File

@@ -0,0 +1,49 @@
package management
import (
"context"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, item := range s.items {
out = append(out, item)
}
return out, nil
}
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.items, id)
return nil
}
func (s *memoryAuthStore) SetBaseDir(string) {}

View File

@@ -403,6 +403,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/google/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
@@ -658,6 +672,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)

View File

@@ -4,12 +4,12 @@ package claude
import (
"net/http"
"net/url"
"strings"
"sync"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct
if cfg != nil && cfg.ProxyURL != "" {
proxyURL, err := url.Parse(cfg.ProxyURL)
if err != nil {
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
} else {
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
} else {
dialer = pDialer
}
if cfg != nil {
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
dialer = proxyDialer
}
}

View File

@@ -10,9 +10,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
@@ -20,9 +18,9 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
} else if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
var err error
// Configure the OAuth2 client.
conf := &oauth2.Config{
ClientID: ClientID,

View File

@@ -0,0 +1,492 @@
package gitlab
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
DefaultBaseURL = "https://gitlab.com"
DefaultCallbackPort = 17171
defaultOAuthScope = "api read_user"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
type OAuthResult struct {
Code string
State string
Error string
}
type OAuthServer struct {
server *http.Server
port int
resultChan chan *OAuthResult
errorChan chan error
mu sync.Mutex
running bool
}
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
CreatedAt int64 `json:"created_at"`
ExpiresIn int `json:"expires_in"`
}
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
PublicEmail string `json:"public_email"`
}
type PersonalAccessTokenSelf struct {
ID int64 `json:"id"`
Name string `json:"name"`
Scopes []string `json:"scopes"`
UserID int64 `json:"user_id"`
}
type ModelDetails struct {
ModelProvider string `json:"model_provider"`
ModelName string `json:"model_name"`
}
type DirectAccessResponse struct {
BaseURL string `json:"base_url"`
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
Headers map[string]string `json:"headers"`
ModelDetails *ModelDetails `json:"model_details,omitempty"`
}
type DiscoveredModel struct {
ModelProvider string
ModelName string
}
type AuthClient struct {
httpClient *http.Client
}
func NewAuthClient(cfg *config.Config) *AuthClient {
client := &http.Client{}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &AuthClient{httpClient: client}
}
func NormalizeBaseURL(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return DefaultBaseURL
}
if !strings.Contains(value, "://") {
value = "https://" + value
}
value = strings.TrimRight(value, "/")
return value
}
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
if token == nil {
return time.Time{}
}
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
}
if token.ExpiresIn > 0 {
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
}
return time.Time{}
}
func GeneratePKCECodes() (*PKCECodes, error) {
verifierBytes := make([]byte, 32)
if _, err := rand.Read(verifierBytes); err != nil {
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
}
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
sum := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
return &PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}
func NewOAuthServer(port int) *OAuthServer {
return &OAuthServer{
port: port,
resultChan: make(chan *OAuthResult, 1),
errorChan: make(chan error, 1),
}
}
func (s *OAuthServer) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.running {
return fmt.Errorf("gitlab oauth server already running")
}
if !s.isPortAvailable() {
return fmt.Errorf("port %d is already in use", s.port)
}
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", s.handleCallback)
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
s.running = true
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
s.errorChan <- err
}
}()
time.Sleep(100 * time.Millisecond)
return nil
}
func (s *OAuthServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.running || s.server == nil {
return nil
}
defer func() {
s.running = false
s.server = nil
}()
return s.server.Shutdown(ctx)
}
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
select {
case result := <-s.resultChan:
return result, nil
case err := <-s.errorChan:
return nil, err
case <-time.After(timeout):
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
}
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
query := r.URL.Query()
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
s.sendResult(&OAuthResult{Error: errParam})
http.Error(w, errParam, http.StatusBadRequest)
return
}
code := strings.TrimSpace(query.Get("code"))
state := strings.TrimSpace(query.Get("state"))
if code == "" || state == "" {
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
http.Error(w, "missing code or state", http.StatusBadRequest)
return
}
s.sendResult(&OAuthResult{Code: code, State: state})
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
}
func (s *OAuthServer) sendResult(result *OAuthResult) {
select {
case s.resultChan <- result:
default:
log.Debug("gitlab oauth result channel full, dropping callback result")
}
}
func (s *OAuthServer) isPortAvailable() bool {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
if err != nil {
return false
}
_ = listener.Close()
return true
}
func RedirectURL(port int) string {
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
}
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
if pkce == nil {
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
}
if strings.TrimSpace(clientID) == "" {
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
}
baseURL = NormalizeBaseURL(baseURL)
params := url.Values{
"client_id": {strings.TrimSpace(clientID)},
"response_type": {"code"},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"scope": {defaultOAuthScope},
"state": {strings.TrimSpace(state)},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
}
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
}
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {strings.TrimSpace(clientID)},
"code": {strings.TrimSpace(code)},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {strings.TrimSpace(codeVerifier)},
}
if secret := strings.TrimSpace(clientSecret); secret != "" {
form.Set("client_secret", secret)
}
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
form := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {strings.TrimSpace(refreshToken)},
}
if clientID = strings.TrimSpace(clientID); clientID != "" {
form.Set("client_id", clientID)
}
if secret := strings.TrimSpace(clientSecret); secret != "" {
form.Set("client_secret", secret)
}
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("gitlab token request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab token request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var token TokenResponse
if err := json.Unmarshal(body, &token); err != nil {
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
}
return &token, nil
}
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
if err != nil {
return nil, fmt.Errorf("gitlab user request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab user request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var user User
if err := json.Unmarshal(body, &user); err != nil {
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
}
return &user, nil
}
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var pat PersonalAccessTokenSelf
if err := json.Unmarshal(body, &pat); err != nil {
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
}
return &pat, nil
}
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
if err != nil {
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var direct DirectAccessResponse
if err := json.Unmarshal(body, &direct); err != nil {
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
}
if direct.Headers == nil {
direct.Headers = make(map[string]string)
}
return &direct, nil
}
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
if len(metadata) == 0 {
return nil
}
models := make([]DiscoveredModel, 0, 4)
seen := make(map[string]struct{})
appendModel := func(provider, name string) {
provider = strings.TrimSpace(provider)
name = strings.TrimSpace(name)
if name == "" {
return
}
key := strings.ToLower(name)
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
models = append(models, DiscoveredModel{
ModelProvider: provider,
ModelName: name,
})
}
if raw, ok := metadata["model_details"]; ok {
appendDiscoveredModels(raw, appendModel)
}
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
for _, key := range []string{"models", "supported_models", "discovered_models"} {
if raw, ok := metadata[key]; ok {
appendDiscoveredModels(raw, appendModel)
}
}
return models
}
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
switch typed := raw.(type) {
case map[string]any:
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
if nested, ok := typed["models"]; ok {
appendDiscoveredModels(nested, appendModel)
}
case []any:
for _, item := range typed {
appendDiscoveredModels(item, appendModel)
}
case []string:
for _, item := range typed {
appendModel("", item)
}
case string:
appendModel("", typed)
}
}
func stringValue(raw any) string {
switch typed := raw.(type) {
case string:
return strings.TrimSpace(typed)
case fmt.Stringer:
return strings.TrimSpace(typed.String())
case json.Number:
return typed.String()
case int:
return strconv.Itoa(typed)
case int64:
return strconv.FormatInt(typed, 10)
case float64:
return strconv.FormatInt(int64(typed), 10)
default:
return ""
}
}

View File

@@ -0,0 +1,138 @@
package gitlab
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
client := NewAuthClient(nil)
pkce, err := GeneratePKCECodes()
if err != nil {
t.Fatalf("GeneratePKCECodes() error = %v", err)
}
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
if err != nil {
t.Fatalf("GenerateAuthURL() error = %v", err)
}
parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("Parse(authURL) error = %v", err)
}
if got := parsed.Path; got != "/oauth/authorize" {
t.Fatalf("expected /oauth/authorize path, got %q", got)
}
query := parsed.Query()
if got := query.Get("client_id"); got != "client-id" {
t.Fatalf("expected client_id, got %q", got)
}
if got := query.Get("scope"); got != defaultOAuthScope {
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
}
if got := query.Get("code_challenge_method"); got != "S256" {
t.Fatalf("expected PKCE method S256, got %q", got)
}
if got := query.Get("code_challenge"); got == "" {
t.Fatal("expected non-empty code_challenge")
}
}
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
t.Fatalf("unexpected path %q", r.URL.Path)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm() error = %v", err)
}
if got := r.Form.Get("grant_type"); got != "authorization_code" {
t.Fatalf("expected authorization_code grant, got %q", got)
}
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
t.Fatalf("expected code_verifier, got %q", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"token_type": "Bearer",
"scope": "api read_user",
"created_at": 1710000000,
"expires_in": 3600,
})
}))
defer srv.Close()
client := NewAuthClient(nil)
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
if err != nil {
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
}
if token.AccessToken != "oauth-access" {
t.Fatalf("expected access token, got %q", token.AccessToken)
}
if token.RefreshToken != "oauth-refresh" {
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
}
}
func TestExtractDiscoveredModels(t *testing.T) {
models := ExtractDiscoveredModels(map[string]any{
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
"supported_models": []any{
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
"claude-sonnet-4-5",
},
})
if len(models) != 2 {
t.Fatalf("expected 2 unique models, got %d", len(models))
}
if models[0].ModelName != "claude-sonnet-4-5" {
t.Fatalf("unexpected first model %q", models[0].ModelName)
}
if models[1].ModelName != "gpt-4.1" {
t.Fatalf("unexpected second model %q", models[1].ModelName)
}
}
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
t.Fatalf("unexpected path %q", r.URL.Path)
}
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
t.Fatalf("expected bearer token, got %q", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1710003600,
"headers": map[string]string{
"X-Gitlab-Realm": "saas",
},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
}))
defer srv.Close()
client := NewAuthClient(nil)
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
if err != nil {
t.Fatalf("FetchDirectAccess() error = %v", err)
}
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
}
}

View File

@@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,69 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
"login_mode": "oauth",
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
if err != nil {
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("GitLab Duo authentication successful!")
}
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
Metadata: map[string]string{
"login_mode": "pat",
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
if err != nil {
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("GitLab Duo PAT authentication successful!")
}

View File

@@ -0,0 +1,32 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.yaml")
configYAML := []byte(`
codex-header-defaults:
user-agent: " my-codex-client/1.0 "
beta-features: " feature-a,feature-b "
`)
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
t.Fatalf("failed to write config: %v", err)
}
cfg, err := LoadConfigOptional(configPath, false)
if err != nil {
t.Fatalf("LoadConfigOptional() error = %v", err)
}
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
}
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
}
}

View File

@@ -101,6 +101,10 @@ type Config struct {
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
// These are used only when the client does not send its own headers.
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
@@ -150,6 +154,14 @@ type ClaudeHeaderDefaults struct {
Timeout string `yaml:"timeout" json:"timeout"`
}
// CodexHeaderDefaults configures fallback header values injected into Codex
// model requests for OAuth/file-backed auth when the client omits them.
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
type CodexHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -679,6 +691,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Codex keys: drop entries without base-url
cfg.SanitizeCodexKeys()
// Sanitize Codex header defaults.
cfg.SanitizeCodexHeaderDefaults()
// Sanitize Claude key headers
cfg.SanitizeClaudeKeys()
@@ -771,6 +786,16 @@ func payloadRawString(value any) ([]byte, bool) {
}
}
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
// configured Codex header fallback values.
func (cfg *Config) SanitizeCodexHeaderDefaults() {
if cfg == nil {
return
}
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.

View File

@@ -1,12 +1,105 @@
// Package registry provides model definitions and lookup helpers for various AI providers.
// Static model metadata is stored in model_definitions_static_data.go.
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
package registry
import (
"sort"
"strings"
)
// staticModelsJSON mirrors the top-level structure of models.json.
type staticModelsJSON struct {
Claude []*ModelInfo `json:"claude"`
Gemini []*ModelInfo `json:"gemini"`
Vertex []*ModelInfo `json:"vertex"`
GeminiCLI []*ModelInfo `json:"gemini-cli"`
AIStudio []*ModelInfo `json:"aistudio"`
CodexFree []*ModelInfo `json:"codex-free"`
CodexTeam []*ModelInfo `json:"codex-team"`
CodexPlus []*ModelInfo `json:"codex-plus"`
CodexPro []*ModelInfo `json:"codex-pro"`
Qwen []*ModelInfo `json:"qwen"`
IFlow []*ModelInfo `json:"iflow"`
Kimi []*ModelInfo `json:"kimi"`
Antigravity []*ModelInfo `json:"antigravity"`
}
// GetClaudeModels returns the standard Claude model definitions.
func GetClaudeModels() []*ModelInfo {
return cloneModelInfos(getModels().Claude)
}
// GetGeminiModels returns the standard Gemini model definitions.
func GetGeminiModels() []*ModelInfo {
return cloneModelInfos(getModels().Gemini)
}
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
func GetGeminiVertexModels() []*ModelInfo {
return cloneModelInfos(getModels().Vertex)
}
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
func GetGeminiCLIModels() []*ModelInfo {
return cloneModelInfos(getModels().GeminiCLI)
}
// GetAIStudioModels returns model definitions for AI Studio.
func GetAIStudioModels() []*ModelInfo {
return cloneModelInfos(getModels().AIStudio)
}
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
func GetCodexFreeModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexFree)
}
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
func GetCodexTeamModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexTeam)
}
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
func GetCodexPlusModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPlus)
}
// GetCodexProModels returns model definitions for the Codex pro plan tier.
func GetCodexProModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPro)
}
// GetQwenModels returns the standard Qwen model definitions.
func GetQwenModels() []*ModelInfo {
return cloneModelInfos(getModels().Qwen)
}
// GetIFlowModels returns the standard iFlow model definitions.
func GetIFlowModels() []*ModelInfo {
return cloneModelInfos(getModels().IFlow)
}
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
func GetKimiModels() []*ModelInfo {
return cloneModelInfos(getModels().Kimi)
}
// GetAntigravityModels returns the standard Antigravity model definitions.
func GetAntigravityModels() []*ModelInfo {
return cloneModelInfos(getModels().Antigravity)
}
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*ModelInfo, len(models))
for i, m := range models {
out[i] = cloneModelInfo(m)
}
return out
}
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
// It returns nil when the channel is unknown.
//
@@ -20,7 +113,6 @@ import (
// - qwen
// - iflow
// - kimi
// - kiro
// - kilo
// - github-copilot
// - amazonq
@@ -39,7 +131,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "aistudio":
return GetAIStudioModels()
case "codex":
return GetOpenAIModels()
return GetCodexProModels()
case "qwen":
return GetQwenModels()
case "iflow":
@@ -55,28 +147,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "amazonq":
return GetAmazonQModels()
case "antigravity":
cfg := GetAntigravityModelConfig()
if len(cfg) == 0 {
return nil
}
models := make([]*ModelInfo, 0, len(cfg))
for modelID, entry := range cfg {
if modelID == "" || entry == nil {
continue
}
models = append(models, &ModelInfo{
ID: modelID,
Object: "model",
OwnedBy: "antigravity",
Type: "antigravity",
Thinking: entry.Thinking,
MaxCompletionTokens: entry.MaxCompletionTokens,
})
}
sort.Slice(models, func(i, j int) bool {
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
})
return models
return GetAntigravityModels()
default:
return nil
}
@@ -89,16 +160,18 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil
}
data := getModels()
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
GetGeminiVertexModels(),
GetGeminiCLIModels(),
GetAIStudioModels(),
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
data.Claude,
data.Gemini,
data.Vertex,
data.GeminiCLI,
data.AIStudio,
data.CodexPro,
data.Qwen,
data.IFlow,
data.Kimi,
data.Antigravity,
GetGitHubCopilotModels(),
GetKiroModels(),
GetKiloModels(),
@@ -107,20 +180,11 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
for _, models := range allModels {
for _, m := range models {
if m != nil && m.ID == modelID {
return m
return cloneModelInfo(m)
}
}
}
// Check Antigravity static config
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
return &ModelInfo{
ID: modelID,
Thinking: cfg.Thinking,
MaxCompletionTokens: cfg.MaxCompletionTokens,
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -64,6 +64,11 @@ type ModelInfo struct {
UserDefined bool `json:"-"`
}
type availableModelsCacheEntry struct {
models []map[string]any
expiresAt time.Time
}
// ThinkingSupport describes a model family's supported internal reasoning budget range.
// Values are interpreted in provider-native token units.
type ThinkingSupport struct {
@@ -118,6 +123,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
availableModelsCache map[string]availableModelsCacheEntry
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
@@ -130,15 +137,28 @@ var registryOnce sync.Once
func GetGlobalRegistry() *ModelRegistry {
registryOnce.Do(func() {
globalRegistry = &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
mutex: &sync.RWMutex{},
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
availableModelsCache: make(map[string]availableModelsCacheEntry),
mutex: &sync.RWMutex{},
}
})
return globalRegistry
}
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
if r.availableModelsCache == nil {
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
}
}
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
if len(r.availableModelsCache) == 0 {
return
}
clear(r.availableModelsCache)
}
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
@@ -153,9 +173,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
}
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
return info
return cloneModelInfo(info)
}
return LookupStaticModelInfo(modelID)
return cloneModelInfo(LookupStaticModelInfo(modelID))
}
// SetHook sets an optional hook for observing model registration changes.
@@ -169,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
}
const defaultModelRegistryHookTimeout = 5 * time.Second
const modelQuotaExceededWindow = 5 * time.Minute
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook
@@ -213,6 +234,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
provider := strings.ToLower(clientProvider)
uniqueModelIDs := make([]string, 0, len(models))
@@ -238,6 +260,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientModels, clientID)
delete(r.clientModelInfos, clientID)
delete(r.clientProviders, clientID)
r.invalidateAvailableModelsCacheLocked()
misc.LogCredentialSeparator()
return
}
@@ -265,6 +288,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
@@ -367,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
reg.InfoByProvider[provider] = cloneModelInfo(model)
}
reg.LastUpdated = now
// Re-registering an existing client/model binding starts a fresh registry
// snapshot for that binding. Cooldown and suspension are transient
// scheduling state and must not survive this reconciliation step.
if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID)
}
@@ -408,6 +435,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
@@ -511,6 +539,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedOutputModalities) > 0 {
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
}
if model.Thinking != nil {
copyThinking := *model.Thinking
if len(model.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
}
copyModel.Thinking = &copyThinking
}
return &copyModel
}
@@ -540,6 +575,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.unregisterClientInternal(clientID)
r.invalidateAvailableModelsCacheLocked()
}
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
@@ -606,9 +642,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists {
registration.QuotaExceededClients[clientID] = new(time.Now())
now := time.Now()
registration.QuotaExceededClients[clientID] = &now
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
}
@@ -620,9 +659,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists {
delete(registration.QuotaExceededClients, clientID)
r.invalidateAvailableModelsCacheLocked()
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
}
}
@@ -638,6 +679,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID]
if !exists || registration == nil {
@@ -651,6 +693,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
}
registration.SuspendedClients[clientID] = reason
registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
if reason != "" {
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
} else {
@@ -668,6 +711,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID]
if !exists || registration == nil || registration.SuspendedClients == nil {
@@ -678,6 +722,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
}
delete(registration.SuspendedClients, clientID)
registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Resumed client %s for model %s", clientID, modelID)
}
@@ -713,22 +758,51 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
// Returns:
// - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
r.mutex.RLock()
defer r.mutex.RUnlock()
now := time.Now()
models := make([]map[string]any, 0)
quotaExpiredDuration := 5 * time.Minute
r.mutex.RLock()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
models := cloneModelMaps(cache.models)
r.mutex.RUnlock()
return models
}
r.mutex.RUnlock()
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
return cloneModelMaps(cache.models)
}
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
models: cloneModelMaps(models),
expiresAt: expiresAt,
}
return models
}
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models))
var expiresAt time.Time
for _, registration := range r.models {
// Check if model has any non-quota-exceeded clients
availableClients := registration.Count
now := time.Now()
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime == nil {
continue
}
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
if now.Before(recoveryAt) {
expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
expiresAt = recoveryAt
}
}
}
@@ -749,7 +823,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
effectiveClients = 0
}
// Include models that have available clients, or those solely cooling down.
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
model := r.convertModelToMap(registration.Info, handlerType)
if model != nil {
@@ -758,7 +831,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
}
}
return models
return models, expiresAt
}
func cloneModelMaps(models []map[string]any) []map[string]any {
cloned := make([]map[string]any, 0, len(models))
for _, model := range models {
if model == nil {
cloned = append(cloned, nil)
continue
}
copyModel := make(map[string]any, len(model))
for key, value := range model {
copyModel[key] = cloneModelMapValue(value)
}
cloned = append(cloned, copyModel)
}
return cloned
}
func cloneModelMapValue(value any) any {
switch typed := value.(type) {
case map[string]any:
copyMap := make(map[string]any, len(typed))
for key, entry := range typed {
copyMap[key] = cloneModelMapValue(entry)
}
return copyMap
case []any:
copySlice := make([]any, len(typed))
for i, entry := range typed {
copySlice[i] = cloneModelMapValue(entry)
}
return copySlice
case []string:
return append([]string(nil), typed...)
default:
return value
}
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
@@ -822,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
return nil
}
quotaExpiredDuration := 5 * time.Minute
now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels))
@@ -844,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++
}
}
@@ -874,11 +983,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil {
result = append(result, entry.info)
result = append(result, cloneModelInfo(entry.info))
continue
}
if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info)
result = append(result, cloneModelInfo(registration.Info))
}
}
}
@@ -898,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
if registration, exists := r.models[modelID]; exists {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++
}
}
@@ -987,13 +1095,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
if reg.Providers != nil {
if count, ok := reg.Providers[provider]; ok && count > 0 {
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
return info
return cloneModelInfo(info)
}
}
}
}
// Fallback to global info (last registered)
return reg.Info
return cloneModelInfo(reg.Info)
}
return nil
}
@@ -1033,7 +1141,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["max_completion_tokens"] = model.MaxCompletionTokens
}
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedEndpoints) > 0 {
result["supported_endpoints"] = model.SupportedEndpoints
@@ -1094,13 +1202,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["outputTokenLimit"] = model.OutputTokenLimit
}
if len(model.SupportedGenerationMethods) > 0 {
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedInputModalities) > 0 {
result["supportedInputModalities"] = model.SupportedInputModalities
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
}
if len(model.SupportedOutputModalities) > 0 {
result["supportedOutputModalities"] = model.SupportedOutputModalities
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
}
return result
@@ -1129,16 +1237,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
defer r.mutex.Unlock()
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
invalidated := false
for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
delete(registration.QuotaExceededClients, clientID)
invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
}
}
}
if invalidated {
r.invalidateAvailableModelsCacheLocked()
}
}
// GetFirstAvailableModel returns the first available model for the given handler type.
@@ -1152,8 +1264,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
// - string: The model ID of the first available model, or empty string if none available
// - error: An error if no models are available
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Get all available models for this handler type
models := r.GetAvailableModels(handlerType)
@@ -1213,13 +1323,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
// Prefer client's own model info to preserve original type/owned_by
if clientInfos != nil {
if info, ok := clientInfos[modelID]; ok && info != nil {
result = append(result, info)
result = append(result, cloneModelInfo(info))
continue
}
}
// Fallback to global registry (for backwards compatibility)
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
result = append(result, reg.Info)
result = append(result, cloneModelInfo(reg.Info))
}
}
return result

View File

@@ -0,0 +1,54 @@
package registry
import "testing"
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
first := r.GetAvailableModels("openai")
if len(first) != 1 {
t.Fatalf("expected 1 model, got %d", len(first))
}
first[0]["id"] = "mutated"
first[0]["display_name"] = "Mutated"
second := r.GetAvailableModels("openai")
if got := second[0]["id"]; got != "m1" {
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
}
if got := second[0]["display_name"]; got != "Model One" {
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
}
}
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected 1 model, got %d", len(models))
}
if got := models[0]["display_name"]; got != "Model One" {
t.Fatalf("expected initial display_name Model One, got %v", got)
}
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
models = r.GetAvailableModels("openai")
if got := models[0]["display_name"]; got != "Model One Updated" {
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
}
r.SuspendClientModel("client-1", "m1", "manual")
models = r.GetAvailableModels("openai")
if len(models) != 0 {
t.Fatalf("expected no available models after suspension, got %d", len(models))
}
r.ResumeClientModel("client-1", "m1")
models = r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to reappear after resume, got %d", len(models))
}
}

View File

@@ -0,0 +1,149 @@
package registry
import (
"testing"
"time"
)
func TestGetModelInfoReturnsClone(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
}})
first := r.GetModelInfo("m1", "gemini")
if first == nil {
t.Fatal("expected model info")
}
first.DisplayName = "mutated"
first.Thinking.Levels[0] = "mutated"
second := r.GetModelInfo("m1", "gemini")
if second.DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
}
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
}
}
func TestGetModelsForClientReturnsClones(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
}})
first := r.GetModelsForClient("client-1")
if len(first) != 1 || first[0] == nil {
t.Fatalf("expected one model, got %+v", first)
}
first[0].DisplayName = "mutated"
first[0].Thinking.Levels[0] = "mutated"
second := r.GetModelsForClient("client-1")
if len(second) != 1 || second[0] == nil {
t.Fatalf("expected one model on second fetch, got %+v", second)
}
if second[0].DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
}
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
}
}
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
}})
first := r.GetAvailableModelsByProvider("gemini")
if len(first) != 1 || first[0] == nil {
t.Fatalf("expected one model, got %+v", first)
}
first[0].DisplayName = "mutated"
first[0].Thinking.Levels[0] = "mutated"
second := r.GetAvailableModelsByProvider("gemini")
if len(second) != 1 || second[0] == nil {
t.Fatalf("expected one model on second fetch, got %+v", second)
}
if second[0].DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
}
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
}
}
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
r.SetModelQuotaExceeded("client-1", "m1")
if models := r.GetAvailableModels("openai"); len(models) != 1 {
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
}
r.mutex.Lock()
quotaTime := time.Now().Add(-6 * time.Minute)
r.models["m1"].QuotaExceededClients["client-1"] = &quotaTime
r.mutex.Unlock()
r.CleanupExpiredQuotas()
if count := r.GetModelCount("m1"); count != 1 {
t.Fatalf("expected model count 1 after cleanup, got %d", count)
}
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
}
if got := models[0]["id"]; got != "m1" {
t.Fatalf("expected model id m1, got %v", got)
}
}
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
SupportedParameters: []string{"temperature", "top_p"},
}})
first := r.GetAvailableModels("openai")
if len(first) != 1 {
t.Fatalf("expected one model, got %d", len(first))
}
params, ok := first[0]["supported_parameters"].([]string)
if !ok || len(params) != 2 {
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
}
params[0] = "mutated"
second := r.GetAvailableModels("openai")
params, ok = second[0]["supported_parameters"].([]string)
if !ok || len(params) != 2 || params[0] != "temperature" {
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
}
}
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
first := LookupModelInfo("glm-4.6")
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
t.Fatalf("expected static model with thinking levels, got %+v", first)
}
first.Thinking.Levels[0] = "mutated"
second := LookupModelInfo("glm-4.6")
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
t.Fatalf("expected static lookup clone, got %+v", second)
}
}

View File

@@ -0,0 +1,372 @@
package registry
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
modelsFetchTimeout = 30 * time.Second
modelsRefreshInterval = 3 * time.Hour
)
var modelsURLs = []string{
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
"https://models.router-for.me/models.json",
}
//go:embed models/models.json
var embeddedModelsJSON []byte
type modelStore struct {
mu sync.RWMutex
data *staticModelsJSON
}
var modelsCatalogStore = &modelStore{}
var updaterOnce sync.Once
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
// changedProviders contains the provider names whose model definitions changed.
type ModelRefreshCallback func(changedProviders []string)
var (
refreshCallbackMu sync.Mutex
refreshCallback ModelRefreshCallback
pendingRefreshChanges []string
)
// SetModelRefreshCallback registers a callback that is invoked when startup or
// periodic model refresh detects changes. Only one callback is supported;
// subsequent calls replace the previous callback.
func SetModelRefreshCallback(cb ModelRefreshCallback) {
refreshCallbackMu.Lock()
refreshCallback = cb
var pending []string
if cb != nil && len(pendingRefreshChanges) > 0 {
pending = append([]string(nil), pendingRefreshChanges...)
pendingRefreshChanges = nil
}
refreshCallbackMu.Unlock()
if cb != nil && len(pending) > 0 {
cb(pending)
}
}
func init() {
// Load embedded data as fallback on startup.
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
}
}
// StartModelsUpdater starts a background updater that fetches models
// immediately on startup and then refreshes the model catalog every 3 hours.
// Safe to call multiple times; only one updater will run.
func StartModelsUpdater(ctx context.Context) {
updaterOnce.Do(func() {
go runModelsUpdater(ctx)
})
}
func runModelsUpdater(ctx context.Context) {
tryStartupRefresh(ctx)
periodicRefresh(ctx)
}
func periodicRefresh(ctx context.Context) {
ticker := time.NewTicker(modelsRefreshInterval)
defer ticker.Stop()
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
tryPeriodicRefresh(ctx)
}
}
}
// tryPeriodicRefresh fetches models from remote, compares with the current
// catalog, and notifies the registered callback if any provider changed.
func tryPeriodicRefresh(ctx context.Context) {
tryRefreshModels(ctx, "periodic model refresh")
}
// tryStartupRefresh fetches models from remote in the background during
// process startup. It uses the same change detection as periodic refresh so
// existing auth registrations can be updated after the callback is registered.
func tryStartupRefresh(ctx context.Context) {
tryRefreshModels(ctx, "startup model refresh")
}
func tryRefreshModels(ctx context.Context, label string) {
oldData := getModels()
parsed, url := fetchModelsFromRemote(ctx)
if parsed == nil {
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
return
}
// Detect changes before updating store.
changed := detectChangedProviders(oldData, parsed)
// Update store with new data regardless.
modelsCatalogStore.mu.Lock()
modelsCatalogStore.data = parsed
modelsCatalogStore.mu.Unlock()
if len(changed) == 0 {
log.Infof("%s completed from %s, no changes detected", label, url)
return
}
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
notifyModelRefresh(changed)
}
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
client := &http.Client{Timeout: modelsFetchTimeout}
for _, url := range modelsURLs {
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
if err != nil {
cancel()
log.Debugf("models fetch request creation failed for %s: %v", url, err)
continue
}
resp, err := client.Do(req)
if err != nil {
cancel()
log.Debugf("models fetch failed from %s: %v", url, err)
continue
}
if resp.StatusCode != 200 {
resp.Body.Close()
cancel()
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
cancel()
if err != nil {
log.Debugf("models fetch read error from %s: %v", url, err)
continue
}
var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
log.Warnf("models parse failed from %s: %v", url, err)
continue
}
if err := validateModelsCatalog(&parsed); err != nil {
log.Warnf("models validate failed from %s: %v", url, err)
continue
}
return &parsed, url
}
return nil, ""
}
// detectChangedProviders compares two model catalogs and returns provider names
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
// under a single "codex" provider.
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
if oldData == nil || newData == nil {
return nil
}
type section struct {
provider string
oldList []*ModelInfo
newList []*ModelInfo
}
sections := []section{
{"claude", oldData.Claude, newData.Claude},
{"gemini", oldData.Gemini, newData.Gemini},
{"vertex", oldData.Vertex, newData.Vertex},
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
{"aistudio", oldData.AIStudio, newData.AIStudio},
{"codex", oldData.CodexFree, newData.CodexFree},
{"codex", oldData.CodexTeam, newData.CodexTeam},
{"codex", oldData.CodexPlus, newData.CodexPlus},
{"codex", oldData.CodexPro, newData.CodexPro},
{"qwen", oldData.Qwen, newData.Qwen},
{"iflow", oldData.IFlow, newData.IFlow},
{"kimi", oldData.Kimi, newData.Kimi},
{"antigravity", oldData.Antigravity, newData.Antigravity},
}
seen := make(map[string]bool, len(sections))
var changed []string
for _, s := range sections {
if seen[s.provider] {
continue
}
if modelSectionChanged(s.oldList, s.newList) {
changed = append(changed, s.provider)
seen[s.provider] = true
}
}
return changed
}
// modelSectionChanged reports whether two model slices differ.
func modelSectionChanged(a, b []*ModelInfo) bool {
if len(a) != len(b) {
return true
}
if len(a) == 0 {
return false
}
aj, err1 := json.Marshal(a)
bj, err2 := json.Marshal(b)
if err1 != nil || err2 != nil {
return true
}
return string(aj) != string(bj)
}
func notifyModelRefresh(changedProviders []string) {
if len(changedProviders) == 0 {
return
}
refreshCallbackMu.Lock()
cb := refreshCallback
if cb == nil {
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
refreshCallbackMu.Unlock()
return
}
refreshCallbackMu.Unlock()
cb(changedProviders)
}
func mergeProviderNames(existing, incoming []string) []string {
if len(incoming) == 0 {
return existing
}
seen := make(map[string]struct{}, len(existing)+len(incoming))
merged := make([]string, 0, len(existing)+len(incoming))
for _, provider := range existing {
name := strings.ToLower(strings.TrimSpace(provider))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
for _, provider := range incoming {
name := strings.ToLower(strings.TrimSpace(provider))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
return merged
}
func loadModelsFromBytes(data []byte, source string) error {
var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
return fmt.Errorf("%s: decode models catalog: %w", source, err)
}
if err := validateModelsCatalog(&parsed); err != nil {
return fmt.Errorf("%s: validate models catalog: %w", source, err)
}
modelsCatalogStore.mu.Lock()
modelsCatalogStore.data = &parsed
modelsCatalogStore.mu.Unlock()
return nil
}
func getModels() *staticModelsJSON {
modelsCatalogStore.mu.RLock()
defer modelsCatalogStore.mu.RUnlock()
return modelsCatalogStore.data
}
func validateModelsCatalog(data *staticModelsJSON) error {
if data == nil {
return fmt.Errorf("catalog is nil")
}
requiredSections := []struct {
name string
models []*ModelInfo
}{
{name: "claude", models: data.Claude},
{name: "gemini", models: data.Gemini},
{name: "vertex", models: data.Vertex},
{name: "gemini-cli", models: data.GeminiCLI},
{name: "aistudio", models: data.AIStudio},
{name: "codex-free", models: data.CodexFree},
{name: "codex-team", models: data.CodexTeam},
{name: "codex-plus", models: data.CodexPlus},
{name: "codex-pro", models: data.CodexPro},
{name: "qwen", models: data.Qwen},
{name: "iflow", models: data.IFlow},
{name: "kimi", models: data.Kimi},
{name: "antigravity", models: data.Antigravity},
}
for _, section := range requiredSections {
if err := validateModelSection(section.name, section.models); err != nil {
return err
}
}
return nil
}
func validateModelSection(section string, models []*ModelInfo) error {
if len(models) == 0 {
return fmt.Errorf("%s section is empty", section)
}
seen := make(map[string]struct{}, len(models))
for i, model := range models {
if model == nil {
return fmt.Errorf("%s[%d] is null", section, i)
}
modelID := strings.TrimSpace(model.ID)
if modelID == "" {
return fmt.Errorf("%s[%d] has empty id", section, i)
}
if _, exists := seen[modelID]; exists {
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
}
seen[modelID] = struct{}{}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,6 @@ import (
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -43,7 +42,6 @@ const (
antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
@@ -55,78 +53,8 @@ const (
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
)
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
cfg *config.Config
@@ -1150,168 +1078,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
}
}
// FetchAntigravityModels retrieves available models using the supplied auth.
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
for idx, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
if errReq != nil {
return fallbackAntigravityPrimaryModels()
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
if host := resolveHost(baseURL); host != "" {
httpReq.Host = host
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
switch modelID {
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
modelCfg := modelConfig[modelID]
// Extract displayName from upstream response, fallback to modelID
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
modelInfo := &registry.ModelInfo{
ID: modelID,
Name: modelID,
Description: displayName,
DisplayName: displayName,
Version: modelID,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Build input modalities from upstream capability flags.
inputModalities := []string{"TEXT"}
if modelData.Get("supportsImages").Bool() {
inputModalities = append(inputModalities, "IMAGE")
}
if modelData.Get("supportsVideo").Bool() {
inputModalities = append(inputModalities, "VIDEO")
}
modelInfo.SupportedInputModalities = inputModalities
modelInfo.SupportedOutputModalities = []string{"TEXT"}
// Token limits from upstream.
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
modelInfo.InputTokenLimit = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
modelInfo.OutputTokenLimit = int(maxOut)
}
// Supported generation methods (Gemini v1beta convention).
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
// Look up Thinking support from static config using upstream model name.
if modelCfg != nil {
if modelCfg.Thinking != nil {
modelInfo.Thinking = modelCfg.Thinking
}
if modelCfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
if auth == nil {
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}

View File

@@ -1,90 +0,0 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
}
return true
})
} else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
partJSON, _ = sjson.Set(partJSON, "text", system.String())
result += "," + partJSON
}
result += "]"
@@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int {
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"]
if !exists {
return
return false
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return
return false
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return
return false
}
if *seen5m {
delete(cc, "ttl")
return true
}
return false
}
func findLastCacheControlIndex(arr []any) int {
@@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
}
seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
@@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
@@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
}
for _, item := range content {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
}
if !modified {
return payload
}
return marshalPayloadObject(payload, root)
}

View File

@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
}
}
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
// Payload where no TTL normalization is needed (all blocks use 1h with no
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
// that json.Marshal would escape to \u003c etc., altering byte identity.
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := normalizeCacheControlTTL(payload)
if !bytes.Equal(out, payload) {
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
@@ -829,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
system := gjson.GetBytes(out, "system")
if !system.IsArray() {
t.Fatalf("system should be an array, got %s", system.Type)
}
blocks := system.Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
}
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
}
if blocks[2].Get("text").String() != "You are a helpful assistant." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
}
}
// Test case 2: Strict mode drops the string system prompt
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, true)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 3: Empty string system prompt does not produce a spurious block
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 4: Array system prompt is unaffected by the string handling
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != "Be concise." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
}
// Test case 5: Special characters in string system prompt survive conversion
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}

View File

@@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, true)
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, false)
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if err != nil {
return nil, err
}
applyCodexHeaders(httpReq, auth, apiKey, true)
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
return httpReq, nil
}
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) {
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
@@ -647,7 +647,8 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if stream {
r.Header.Set("Accept", "text/event-stream")

View File

@@ -23,6 +23,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -31,7 +32,7 @@ import (
)
const (
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
codexResponsesWebsocketHandshakeTO = 30 * time.Second
)
@@ -57,11 +58,6 @@ type codexWebsocketSession struct {
wsURL string
authID string
// connCreateSent tracks whether a `response.create` message has been successfully sent
// on the current websocket connection. The upstream expects the first message on each
// connection to be `response.create`.
connCreateSent bool
writeMu sync.Mutex
activeMu sync.Mutex
@@ -195,7 +191,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -212,13 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
defer sess.reqMu.Unlock()
}
allowAppend := true
if sess != nil {
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
}
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -280,10 +270,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
// execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil {
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -312,7 +299,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, errSend
}
}
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
for {
if ctx != nil && ctx.Err() != nil {
@@ -400,29 +386,23 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
executionSessionID := executionSessionIDFromOptions(opts)
var sess *codexWebsocketSession
if executionSessionID != "" {
sess = e.getOrCreateSession(executionSessionID)
sess.reqMu.Lock()
if sess != nil {
sess.reqMu.Lock()
}
}
allowAppend := true
if sess != nil {
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
}
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -483,10 +463,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
sess.reqMu.Unlock()
return nil, errDialRetry
}
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -515,7 +492,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, errSend
}
}
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
@@ -657,31 +633,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
return conn.WriteMessage(websocket.TextMessage, payload)
}
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
func buildCodexWebsocketRequestBody(body []byte) []byte {
if len(body) == 0 {
return nil
}
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
//
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
// so we only use `response.append` after we have initialized the current connection.
if allowAppend {
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
inputNode := gjson.GetBytes(body, "input")
wsReqBody := []byte(`{}`)
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
return wsReqBody
}
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
return wsReqBody
}
}
// Match codex-rs websocket v2 semantics: every request is `response.create`.
// Incremental follow-up turns continue on the same websocket using
// `previous_response_id` + incremental `input`, not `response.append`.
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
if errSet == nil && len(wsReqBody) > 0 {
return wsReqBody
@@ -725,21 +684,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
}
}
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
if sess == nil || conn == nil || len(payload) == 0 {
return
}
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
return
}
sess.connMu.Lock()
if sess.conn == conn {
sess.connCreateSent = true
}
sess.connMu.Unlock()
}
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
@@ -762,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return dialer
}
parsedURL, errParse := url.Parse(proxyURL)
setting, errParse := proxyutil.Parse(proxyURL)
if errParse != nil {
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
log.Errorf("codex websockets executor: %v", errParse)
return dialer
}
switch parsedURL.Scheme {
switch setting.Mode {
case proxyutil.ModeDirect:
dialer.Proxy = nil
return dialer
case proxyutil.ModeProxy:
default:
return dialer
}
switch setting.URL.Scheme {
case "socks5":
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
if setting.URL.User != nil {
username := setting.URL.User.Username()
password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
return dialer
@@ -786,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return socksDialer.Dial(network, addr)
}
case "http", "https":
dialer.Proxy = http.ProxyURL(parsedURL)
dialer.Proxy = http.ProxyURL(setting.URL)
default:
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
}
return dialer
@@ -844,7 +797,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
return rawJSON, headers
}
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header {
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
if headers == nil {
headers = http.Header{}
}
@@ -857,7 +810,8 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header
}
misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "")
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
@@ -872,7 +826,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent)
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
isAPIKey := false
if auth != nil && auth.Attributes != nil {
@@ -900,6 +854,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers
}
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return "", ""
}
}
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(target.Get(key)) != "" {
return
}
if source != nil {
if val := strings.TrimSpace(source.Get(key)); val != "" {
target.Set(key, val)
return
}
}
if val := strings.TrimSpace(configValue); val != "" {
target.Set(key, val)
return
}
if val := strings.TrimSpace(fallbackValue); val != "" {
target.Set(key, val)
}
}
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(target.Get(key)) != "" {
return
}
if val := strings.TrimSpace(configValue); val != "" {
target.Set(key, val)
return
}
if source != nil {
if val := strings.TrimSpace(source.Get(key)); val != "" {
target.Set(key, val)
return
}
}
if val := strings.TrimSpace(fallbackValue); val != "" {
target.Set(key, val)
}
}
type statusErrWithHeaders struct {
statusErr
headers http.Header
@@ -1017,36 +1027,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
}
}
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
done := make(chan struct{})
if ctx == nil || conn == nil {
return done
}
go func() {
select {
case <-done:
case <-ctx.Done():
_ = conn.Close()
}
}()
return done
}
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
done := make(chan struct{})
if ctx == nil || conn == nil {
return done
}
go func() {
select {
case <-done:
case <-ctx.Done():
_ = conn.SetReadDeadline(time.Now())
}
}()
return done
}
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
if len(opts.Metadata) == 0 {
return ""
@@ -1120,7 +1100,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
sess.conn = conn
sess.wsURL = wsURL
sess.authID = authID
sess.connCreateSent = false
sess.readerConn = conn
sess.connMu.Unlock()
@@ -1206,7 +1185,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
return
}
sess.conn = nil
sess.connCreateSent = false
if sess.readerConn == conn {
sess.readerConn = nil
}
@@ -1273,7 +1251,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
authID := sess.authID
wsURL := sess.wsURL
sess.conn = nil
sess.connCreateSent = false
if sess.readerConn == conn {
sess.readerConn = nil
}

View File

@@ -0,0 +1,203 @@
package executor
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
wsReqBody := buildCodexWebsocketRequestBody(body)
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
t.Fatalf("type = %s, want response.create", got)
}
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
t.Fatalf("previous_response_id = %s, want resp-1", got)
}
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
t.Fatalf("input item id mismatch")
}
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
t.Fatalf("unexpected websocket request type: %s", got)
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "my-codex-client/1.0",
BetaFeatures: "feature-a,feature-b",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
}
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
}
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
"X-Codex-Beta-Features": "client-beta",
})
headers := http.Header{}
headers.Set("User-Agent", "existing-ua")
headers.Set("X-Codex-Beta-Features", "existing-beta")
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
}
}
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
"X-Codex-Beta-Features": "client-beta",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
}
}
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Attributes: map[string]string{"api_key": "sk-test"},
}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
req = req.WithContext(contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
}))
applyCodexHeaders(req, auth, "oauth-token", true, cfg)
if got := req.Header.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := req.Header.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func contextWithGinHeaders(headers map[string]string) context.Context {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
ginCtx.Request.Header = make(http.Header, len(headers))
for key, value := range headers {
ginCtx.Request.Header.Set(key, value)
}
return context.WithValue(context.Background(), "gin", ginCtx)
}
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Parallel()
dialer := newProxyAwareWebsocketDialer(
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
)
if dialer.Proxy != nil {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
}

View File

@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
// Imagen models don't support streaming, skip SSE params
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")

View File

@@ -522,9 +522,9 @@ func detectLastConversationRole(body []byte) string {
}
switch item.Get("type").String() {
case "function_call", "function_call_arguments":
case "function_call", "function_call_arguments", "computer_call":
return "assistant"
case "function_call_output", "function_call_response", "tool_result":
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
return "tool"
}
}
@@ -653,6 +653,7 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
}
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
body = stripGitHubCopilotResponsesUnsupportedFields(body)
input := gjson.GetBytes(body, "input")
if input.Exists() {
// If input is already a string or array, keep it as-is.
@@ -825,6 +826,12 @@ func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
return body
}
func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
// GitHub Copilot /responses rejects service_tier, so always remove it.
body, _ = sjson.DeleteBytes(body, "service_tier")
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
@@ -832,6 +839,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
if tools.IsArray() {
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if isGitHubCopilotResponsesBuiltinTool(toolType) {
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
continue
}
// Accept OpenAI format (type="function") and Claude format
// (no type field, but has top-level name + input_schema).
if toolType != "" && toolType != "function" {
@@ -879,6 +890,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
}
if toolChoice.Type == gjson.JSON {
choiceType := toolChoice.Get("type").String()
if isGitHubCopilotResponsesBuiltinTool(choiceType) {
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(toolChoice.Raw))
return body
}
if choiceType == "function" {
name := toolChoice.Get("name").String()
if name == "" {
@@ -896,6 +911,15 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
return body
}
func isGitHubCopilotResponsesBuiltinTool(toolType string) bool {
switch strings.TrimSpace(toolType) {
case "computer", "computer_use_preview":
return true
default:
return false
}
}
func collectTextFromNode(node gjson.Result) string {
if !node.Exists() {
return ""

View File

@@ -132,6 +132,19 @@ func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testi
}
}
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
t.Parallel()
body := []byte(`{"input":"user text","service_tier":"default"}`)
got := normalizeGitHubCopilotResponsesInput(body)
if gjson.GetBytes(got, "service_tier").Exists() {
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
}
if gjson.GetBytes(got, "input").String() != "user text" {
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,469 @@
package executor
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestGitLabExecutorExecuteUsesChatEndpoint(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != gitLabChatEndpoint {
t.Fatalf("unexpected path %q", r.URL.Path)
}
_, _ = w.Write([]byte(`"chat response"`))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "chat response" {
t.Fatalf("expected chat response, got %q", got)
}
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-sonnet-4-5" {
t.Fatalf("expected resolved model, got %q", got)
}
}
func TestGitLabExecutorExecuteFallsBackToCodeSuggestions(t *testing.T) {
chatCalls := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case gitLabChatEndpoint:
chatCalls++
http.Error(w, "feature unavailable", http.StatusForbidden)
case gitLabCodeSuggestionsEndpoint:
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{{
"text": "fallback response",
}},
})
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"personal_access_token": "glpat-token",
"auth_method": "pat",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"write code"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if chatCalls != 1 {
t.Fatalf("expected chat endpoint to be tried once, got %d", chatCalls)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "fallback response" {
t.Fatalf("expected fallback response, got %q", got)
}
}
func TestGitLabExecutorExecuteUsesAnthropicGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{
"model":"gitlab-duo",
"messages":[{"role":"user","content":[{"type":"text","text":"list files"}]}],
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}],
"max_tokens":128
}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotModel != "claude-sonnet-4-5" {
t.Fatalf("model = %q, want claude-sonnet-4-5", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "content.0.type").String(); got != "tool_use" {
t.Fatalf("expected tool_use response, got %q", got)
}
if got := gjson.GetBytes(resp.Payload, "content.0.name").String(); got != "Bash" {
t.Fatalf("expected tool name Bash, got %q", got)
}
}
func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from openai gateway\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from openai gateway\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "openai",
"model_name": "gpt-5-codex",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotModel != "gpt-5-codex" {
t.Fatalf("model = %q, want gpt-5-codex", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from openai gateway" {
t.Fatalf("expected openai gateway response, got %q payload=%s", got, string(resp.Payload))
}
}
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oauth/token":
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "oauth-refreshed",
"refresh_token": "oauth-refresh",
"token_type": "Bearer",
"scope": "api read_user",
"created_at": 1710000000,
"expires_in": 3600,
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1710003600,
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
ID: "gitlab-auth.json",
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"oauth_client_id": "client-id",
"oauth_client_secret": "client-secret",
"auth_method": "oauth",
"oauth_expires_at": "2000-01-01T00:00:00Z",
},
}
updated, err := exec.Refresh(context.Background(), auth)
if err != nil {
t.Fatalf("Refresh() error = %v", err)
}
if got := updated.Metadata["access_token"]; got != "oauth-refreshed" {
t.Fatalf("expected refreshed access token, got %#v", got)
}
if got := updated.Metadata["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("expected refreshed model metadata, got %#v", got)
}
}
func TestGitLabExecutorExecuteStreamUsesCodeSuggestionsSSE(t *testing.T) {
var gotAccept, gotStreamingHeader, gotEncoding string
var gotStreamFlag bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != gitLabCodeSuggestionsEndpoint {
t.Fatalf("unexpected path %q", r.URL.Path)
}
gotAccept = r.Header.Get("Accept")
gotStreamingHeader = r.Header.Get(gitLabSSEStreamingHeader)
gotEncoding = r.Header.Get("Accept-Encoding")
gotStreamFlag = gjson.GetBytes(readBody(t, r), "stream").Bool()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: stream_start\n"))
_, _ = w.Write([]byte("data: {\"model\":{\"name\":\"claude-sonnet-4-5\"}}\n\n"))
_, _ = w.Write([]byte("event: content_chunk\n"))
_, _ = w.Write([]byte("data: {\"content\":\"hello\"}\n\n"))
_, _ = w.Write([]byte("event: content_chunk\n"))
_, _ = w.Write([]byte("data: {\"content\":\" world\"}\n\n"))
_, _ = w.Write([]byte("event: stream_end\n"))
_, _ = w.Write([]byte("data: {}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
}
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
lines := collectStreamLines(t, result)
if gotAccept != "text/event-stream" {
t.Fatalf("Accept = %q, want text/event-stream", gotAccept)
}
if gotStreamingHeader != "true" {
t.Fatalf("%s = %q, want true", gitLabSSEStreamingHeader, gotStreamingHeader)
}
if gotEncoding != "identity" {
t.Fatalf("Accept-Encoding = %q, want identity", gotEncoding)
}
if !gotStreamFlag {
t.Fatalf("expected upstream request to set stream=true")
}
if len(lines) < 4 {
t.Fatalf("expected translated stream chunks, got %d", len(lines))
}
if !strings.Contains(strings.Join(lines, "\n"), `"content":"hello"`) {
t.Fatalf("expected hello delta in stream, got %q", strings.Join(lines, "\n"))
}
if !strings.Contains(strings.Join(lines, "\n"), `"content":" world"`) {
t.Fatalf("expected world delta in stream, got %q", strings.Join(lines, "\n"))
}
last := lines[len(lines)-1]
if last != "data: [DONE]" && !strings.Contains(last, `"finish_reason":"stop"`) {
t.Fatalf("expected stream terminator, got %q", last)
}
}
func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
chatCalls := 0
streamCalls := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case gitLabCodeSuggestionsEndpoint:
streamCalls++
http.Error(w, "feature unavailable", http.StatusForbidden)
case gitLabChatEndpoint:
chatCalls++
_, _ = w.Write([]byte(`"chat fallback response"`))
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
}
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
lines := collectStreamLines(t, result)
if streamCalls != 1 {
t.Fatalf("expected streaming endpoint once, got %d", streamCalls)
}
if chatCalls != 1 {
t.Fatalf("expected chat fallback once, got %d", chatCalls)
}
if !strings.Contains(strings.Join(lines, "\n"), `"content":"chat fallback response"`) {
t.Fatalf("expected fallback content in stream, got %q", strings.Join(lines, "\n"))
}
}
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
var gotPath string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: message_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
_, _ = w.Write([]byte("event: content_block_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
_, _ = w.Write([]byte("event: content_block_delta\n"))
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from gateway\"}}\n\n"))
_, _ = w.Write([]byte("event: message_delta\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
_, _ = w.Write([]byte("event: message_stop\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"max_tokens":64}`),
}
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
lines := collectStreamLines(t, result)
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
}
}
func collectStreamLines(t *testing.T, result *cliproxyexecutor.StreamResult) []string {
t.Helper()
lines := make([]string, 0, 8)
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
lines = append(lines, string(chunk.Payload))
}
return lines
}
func readBody(t *testing.T, r *http.Request) []byte {
t.Helper()
defer func() { _ = r.Body.Close() }()
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
return body
}

View File

@@ -2,17 +2,15 @@ package executor
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
@@ -111,45 +109,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
// Returns:
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
func buildProxyTransport(proxyURL string) *http.Transport {
if proxyURL == "" {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil
}
parsedURL, errParse := url.Parse(proxyURL)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
return nil
}
var transport *http.Transport
// Handle different proxy schemes
if parsedURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
return nil
}
return transport
}

View File

@@ -0,0 +1,30 @@
package executor
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
0,
)
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}

View File

@@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
if suffixResult.HasSuffix {
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
} else {
config = extractThinkingConfig(body, toFormat)
config = extractThinkingConfig(body, fromFormat)
if !hasThinkingConfig(config) && fromFormat != toFormat {
config = extractThinkingConfig(body, toFormat)
}
}
if !hasThinkingConfig(config) {
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
if config.Mode != ModeLevel {
return config
}
if toFormat == "claude" {
return config
}
if !isBudgetCapableProvider(toFormat) {
return config
}

View File

@@ -0,0 +1,55 @@
package thinking_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
"github.com/tidwall/gjson"
)
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-user-defined-claude-" + t.Name()
modelID := "custom-claude-4-6"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
t.Cleanup(func() {
reg.UnregisterClient(clientID)
})
tests := []struct {
name string
model string
body []byte
}{
{
name: "claude adaptive effort body",
model: modelID,
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
},
{
name: "suffix level",
model: modelID + "(high)",
body: []byte(`{}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
if err != nil {
t.Fatalf("ApplyThinking() error = %v", err)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
}
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
}
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
}
})
}
}

View File

@@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
if effort == "max" {
effort = "high"
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")

View File

@@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
}
}
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
tests := []struct {
name string
effort string
expected string
}{
{"low", "low", "low"},
{"medium", "medium", "medium"},
{"high", "high", "high"},
{"max", "max", "high"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"},
"output_config": {"effort": "` + tt.effort + `"}
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
outputStr := string(output)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if !thinkingConfig.Exists() {
t.Fatal("thinkingConfig should exist for adaptive thinking")
}
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
})
}
}
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"}
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
outputStr := string(output)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if !thinkingConfig.Exists() {
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
}
if thinkingConfig.Get("thinkingLevel").String() != "high" {
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
}

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)

View File

@@ -138,20 +138,31 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ResponsesNeeded int
CallNames []string // ordered function call names for backfilling empty response names
}
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
// Falls back to a minimal "functionResponse" object when parsing fails.
func parseFunctionResponseRaw(response gjson.Result) string {
// fallbackName is used when the response's own name is empty.
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
if response.IsObject() && gjson.Valid(response.Raw) {
return response.Raw
raw := response.Raw
name := response.Get("functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName)
}
return raw
}
log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse")
if funcResp.Exists() {
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
name := funcResp.Get("name").String()
if strings.TrimSpace(name) == "" {
name = fallbackName
}
fr, _ = sjson.Set(fr, "functionResponse.name", name)
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
if id := funcResp.Get("id").String(); id != "" {
fr, _ = sjson.Set(fr, "functionResponse.id", id)
@@ -159,7 +170,12 @@ func parseFunctionResponseRaw(response gjson.Result) string {
return fr
}
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
useName := fallbackName
if useName == "" {
useName = "unknown"
}
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.name", useName)
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
return fr
}
@@ -211,30 +227,26 @@ func fixCLIToolResponse(input string) (string, error) {
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Check if pending groups can be satisfied (FIFO: oldest group first)
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
group := pendingGroups[0]
pendingGroups = pendingGroups[1:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
}
@@ -243,15 +255,15 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group
if role == "model" {
functionCallsCount := 0
var callNames []string
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsCount++
callNames = append(callNames, part.Get("functionCall.name").String())
}
return true
})
if functionCallsCount > 0 {
if len(callNames) > 0 {
// Add the model content
if !value.IsObject() {
log.Warnf("failed to parse model content")
@@ -261,7 +273,8 @@ func fixCLIToolResponse(input string) (string, error) {
// Create a new group for tracking responses
group := &FunctionCallGroup{
ResponsesNeeded: functionCallsCount,
ResponsesNeeded: len(callNames),
CallNames: callNames,
}
pendingGroups = append(pendingGroups, group)
} else {
@@ -291,8 +304,8 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response)
for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}

View File

@@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
}
}
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
// When the Amp client sends functionResponse with an empty name,
// fixCLIToolResponse should backfill it from the corresponding functionCall.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
name := funcContent.Get("parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
}
}
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
// Parallel function calls: both responses have empty names.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
parts := funcContent.Get("parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 function response parts, got %d", len(parts))
}
name0 := parts[0].Get("functionResponse.name").String()
name1 := parts[1].Get("functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first response name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second response name 'Grep', got '%s'", name1)
}
}
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
// When functionResponse already has a valid name, it should be preserved.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
name := funcContent.Get("parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
}
}
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
// If there are more function responses than calls, unmatched extras are discarded by grouping.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
// First response should be backfilled from the call
name0 := funcContent.Get("parts.0.functionResponse.name").String()
if name0 != "Bash" {
t.Errorf("Expected first response name 'Bash', got '%s'", name0)
}
}
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
// Two sequential function call groups should be matched FIFO.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "file content"}}}
]
},
{
"role": "model",
"parts": [
{"functionCall": {"name": "Grep", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "match"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContents []gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContents = append(funcContents, c)
}
}
if len(funcContents) != 2 {
t.Fatalf("Expected 2 function contents, got %d", len(funcContents))
}
name0 := funcContents[0].Get("parts.0.functionResponse.name").String()
name1 := funcContents[1].Get("parts.0.functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first group name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
}
}

View File

@@ -212,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else {
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
}
case "input_audio":
audioData := item.Get("input_audio.data").String()
audioFormat := item.Get("input_audio.format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
p++
}
}
}
}

View File

@@ -203,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
case "image_url":
// Convert OpenAI image format to Claude Code format
imageURL := part.Get("image_url.url").String()
if strings.HasPrefix(imageURL, "data:") {
// Extract base64 data and media type from data URL
parts := strings.Split(imageURL, ",")
if len(parts) == 2 {
mediaTypePart := strings.Split(parts[0], ";")[0]
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
data := parts[1]
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
}
return true
})
@@ -291,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "tool":
// Handle tool result messages conversion
toolCallID := message.Get("tool_call_id").String()
content := message.Get("content").String()
toolContentResult := message.Get("content")
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
msg, _ = sjson.Set(msg, "content.0.content", content)
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
if toolResultContentRaw {
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
} else {
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
messageIndex++
}
@@ -358,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
return []byte(out)
}
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
switch part.Get("type").String() {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
return textPart
case "image_url":
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
return docPart
}
}
}
return ""
}
func convertOpenAIImageURLToClaudePart(imageURL string) string {
if imageURL == "" {
return ""
}
if strings.HasPrefix(imageURL, "data:") {
parts := strings.SplitN(imageURL, ",", 2)
if len(parts) != 2 {
return ""
}
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
if mediaType == "" {
mediaType = "application/octet-stream"
}
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
return imagePart
}
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
return imagePart
}
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
if !content.Exists() {
return "", false
}
if content.Type == gjson.String {
return content.String(), false
}
if content.IsArray() {
claudeContent := "[]"
partCount := 0
content.ForEach(func(_, part gjson.Result) bool {
if part.Type == gjson.String {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.String())
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
partCount++
return true
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
partCount++
}
return true
})
if partCount > 0 || len(content.Array()) == 0 {
return claudeContent, true
}
return content.Raw, false
}
if content.IsObject() {
claudePart := convertOpenAIContentPartToClaudePart(content)
if claudePart != "" {
claudeContent := "[]"
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
return claudeContent, true
}
return content.Raw, false
}
return content.Raw, false
}

View File

@@ -0,0 +1,137 @@
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolResult := messages[1].Get("content.0")
if got := toolResult.Get("type").String(); got != "tool_result" {
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
}
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
}
toolContent := toolResult.Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image" {
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("1.source.type").String(); got != "base64" {
t.Fatalf("Expected image source type %q, got %q", "base64", got)
}
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
}
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected base64 image data: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://example.com/tool.png"
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content.0.content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image" {
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("0.source.type").String(); got != "url" {
t.Fatalf("Expected image source type %q, got %q", "url", got)
}
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image URL: %q", got)
}
}

View File

@@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Process system messages and convert them to input content format.
systemsResult := rootResult.Get("system")
if systemsResult.IsArray() {
systemResults := systemsResult.Array()
if systemsResult.Exists() {
message := `{"type":"message","role":"developer","content":[]}`
contentIndex := 0
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
systemTypeResult := systemResult.Get("type")
if systemTypeResult.String() == "text" {
text := systemResult.Get("text").String()
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
continue
appendSystemText := func(text string) {
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
return
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
if systemsResult.Type == gjson.String {
appendSystemText(systemsResult.String())
} else if systemsResult.IsArray() {
systemResults := systemsResult.Array()
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
if systemResult.Get("type").String() == "text" {
appendSystemText(systemResult.Get("text").String())
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
}
if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message)
}

View File

@@ -0,0 +1,89 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantHasDeveloper bool
wantTexts []string
}{
{
name: "No system field",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "Empty string system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "String system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Be helpful"},
},
{
name: "Array system field with filtered billing header",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Block 1", "Block 2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
inputs := resultJSON.Get("input").Array()
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
if hasDeveloper != tt.wantHasDeveloper {
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
}
if !tt.wantHasDeveloper {
return
}
content := inputs[0].Get("content").Array()
if len(content) != len(tt.wantTexts) {
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
}
for i, wantText := range tt.wantTexts {
if gotType := content[i].Get("type").String(); gotType != "input_text" {
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
}
if gotText := content[i].Get("text").String(); gotText != wantText {
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
}
}
})
}
}

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {

View File

@@ -25,7 +25,12 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
// rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
if v := gjson.GetBytes(rawJSON, "service_tier"); v.Exists() {
if v.String() != "priority" {
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
}
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
rawJSON = applyResponsesCompactionCompatibility(rawJSON)

View File

@@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)

View File

@@ -7,6 +7,7 @@ package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -116,6 +117,17 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ResponsesNeeded int
CallNames []string // ordered function call names for backfilling empty response names
}
// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name,
// falling back to fallbackName if the original is empty.
func backfillFunctionResponseName(raw string, fallbackName string) string {
name := gjson.Get(raw, "functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName)
}
return raw
}
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -165,31 +177,28 @@ func fixCLIToolResponse(input string) (string, error) {
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Check if pending groups can be satisfied (FIFO: oldest group first)
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
group := pendingGroups[0]
pendingGroups = pendingGroups[1:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
for ri, response := range groupResponses {
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw)
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
}
@@ -198,15 +207,15 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group
if role == "model" {
functionCallsCount := 0
var callNames []string
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsCount++
callNames = append(callNames, part.Get("functionCall.name").String())
}
return true
})
if functionCallsCount > 0 {
if len(callNames) > 0 {
// Add the model content
if !value.IsObject() {
log.Warnf("failed to parse model content")
@@ -216,7 +225,8 @@ func fixCLIToolResponse(input string) (string, error) {
// Create a new group for tracking responses
group := &FunctionCallGroup{
ResponsesNeeded: functionCallsCount,
ResponsesNeeded: len(callNames),
CallNames: callNames,
}
pendingGroups = append(pendingGroups, group)
} else {
@@ -246,12 +256,13 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
for ri, response := range groupResponses {
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw)
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {

View File

@@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", clientToolName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {

View File

@@ -5,9 +5,11 @@ package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -95,6 +97,71 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
out = []byte(strJson)
}
// Backfill empty functionResponse.name from the preceding functionCall.name.
// Amp may send function responses with empty names; the Gemini API rejects these.
out = backfillEmptyFunctionResponseNames(out)
out = common.AttachDefaultSafetySettings(out, "safetySettings")
return out
}
// backfillEmptyFunctionResponseNames walks the contents array and for each
// model turn containing functionCall parts, records the call names in order.
// For the immediately following user/function turn containing functionResponse
// parts, any empty name is replaced with the corresponding call name.
func backfillEmptyFunctionResponseNames(data []byte) []byte {
contents := gjson.GetBytes(data, "contents")
if !contents.Exists() {
return data
}
out := data
var pendingCallNames []string
contents.ForEach(func(contentIdx, content gjson.Result) bool {
role := content.Get("role").String()
// Collect functionCall names from model turns
if role == "model" {
var names []string
content.Get("parts").ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
names = append(names, part.Get("functionCall.name").String())
}
return true
})
if len(names) > 0 {
pendingCallNames = names
} else {
pendingCallNames = nil
}
return true
}
// Backfill empty functionResponse names from pending call names
if len(pendingCallNames) > 0 {
ri := 0
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
name := part.Get("functionResponse.name").String()
if strings.TrimSpace(name) == "" {
if ri < len(pendingCallNames) {
out, _ = sjson.SetBytes(out,
fmt.Sprintf("contents.%d.parts.%d.functionResponse.name", contentIdx.Int(), partIdx.Int()),
pendingCallNames[ri])
} else {
log.Debugf("more function responses than calls at contents[%d], skipping name backfill", contentIdx.Int())
}
}
ri++
}
return true
})
pendingCallNames = nil
}
return true
})
return out
}

View File

@@ -0,0 +1,193 @@
package gemini
import (
"testing"
"github.com/tidwall/gjson"
)
func TestBackfillEmptyFunctionResponseNames_Single(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}`)
out := backfillEmptyFunctionResponseNames(input)
name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
}
}
func TestBackfillEmptyFunctionResponseNames_Parallel(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
]
}
]
}`)
out := backfillEmptyFunctionResponseNames(input)
name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
name1 := gjson.GetBytes(out, "contents.1.parts.1.functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second name 'Grep', got '%s'", name1)
}
}
func TestBackfillEmptyFunctionResponseNames_PreservesExisting(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
]
}
]
}`)
out := backfillEmptyFunctionResponseNames(input)
name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
}
}
func TestConvertGeminiRequestToGemini_BackfillsEmptyName(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}`)
out := ConvertGeminiRequestToGemini("", input, false)
name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
}
}
func TestBackfillEmptyFunctionResponseNames_MoreResponsesThanCalls(t *testing.T) {
// Extra responses beyond the call count should not panic and should be left unchanged.
input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
]
}
]
}`)
out := backfillEmptyFunctionResponseNames(input)
name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
if name0 != "Bash" {
t.Errorf("Expected first name 'Bash', got '%s'", name0)
}
// Second response has no matching call, should remain empty
name1 := gjson.GetBytes(out, "contents.1.parts.1.functionResponse.name").String()
if name1 != "" {
t.Errorf("Expected second name to remain empty, got '%s'", name1)
}
}
func TestBackfillEmptyFunctionResponseNames_MultipleGroups(t *testing.T) {
// Two sequential call/response groups should each get correct names.
input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content"}}}
]
},
{
"role": "model",
"parts": [
{"functionCall": {"name": "Grep", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "match"}}}
]
}
]
}`)
out := backfillEmptyFunctionResponseNames(input)
name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
name1 := gjson.GetBytes(out, "contents.3.parts.0.functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first group name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
}
}

View File

@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
content := m.Get("content")
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style
// system -> systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
systemPartIndex++
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
systemPartIndex++
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
systemPartIndex++
}
}

View File

@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if instructions := root.Get("instructions"); instructions.Exists() {
systemInstr := `{"parts":[{"text":""}]}`
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
// Convert input messages to Gemini contents format
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if strings.EqualFold(itemRole, "system") {
if contentArray := item.Get("content"); contentArray.Exists() {
systemInstr := ""
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
systemInstr = systemInstructionResult.Raw
} else {
systemInstr = `{"parts":[]}`
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
}
if systemInstr != `{"parts":[]}` {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
}
continue
@@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
}
}
case "input_audio":
audioData := contentItem.Get("data").String()
audioFormat := contentItem.Get("format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
}
}
if partJSON != "" {

View File

@@ -183,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
if toolResultContentRaw {
toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
} else {
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
}
toolResults = append(toolResults, toolResultJSON)
}
return true
@@ -374,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
}
}
func convertClaudeToolResultContentToString(content gjson.Result) string {
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
if !content.Exists() {
return ""
return "", false
}
if content.Type == gjson.String {
return content.String()
return content.String(), false
}
if content.IsArray() {
var parts []string
contentJSON := "[]"
hasImagePart := false
content.ForEach(func(_, item gjson.Result) bool {
switch {
case item.Type == gjson.String:
parts = append(parts, item.String())
text := item.String()
parts = append(parts, text)
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text)
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
case item.IsObject() && item.Get("type").String() == "text":
text := item.Get("text").String()
parts = append(parts, text)
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text)
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
case item.IsObject() && item.Get("type").String() == "image":
contentItem, ok := convertClaudeContentPart(item)
if ok {
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
hasImagePart = true
} else {
parts = append(parts, item.Raw)
}
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
parts = append(parts, item.Get("text").String())
default:
@@ -397,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string {
return true
})
if hasImagePart {
return contentJSON, true
}
joined := strings.Join(parts, "\n\n")
if strings.TrimSpace(joined) != "" {
return joined
return joined, false
}
return content.Raw
return content.Raw, false
}
if content.IsObject() {
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
if content.Get("type").String() == "image" {
contentItem, ok := convertClaudeContentPart(content)
if ok {
contentJSON := "[]"
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
return contentJSON, true
}
}
return content.Raw
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String(), false
}
return content.Raw, false
}
return content.Raw
return content.Raw, false
}

View File

@@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image_url" {
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": {
"type": "image",
"source": {
"type": "url",
"url": "https://example.com/tool.png"
}
}
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image_url" {
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",

View File

@@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_start for tool_use
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
}
@@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
@@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
argsStr := util.FixJSON(tc.Get("function.arguments").String())
@@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())

View File

@@ -0,0 +1,24 @@
package util
import (
"fmt"
"regexp"
"sync/atomic"
"time"
)
var (
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
claudeToolUseIDCounter uint64
)
// SanitizeClaudeToolID ensures the given id conforms to Claude's
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
// replaced with '_'; an empty result gets a generated fallback.
func SanitizeClaudeToolID(id string) string {
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
if s == "" {
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
}
return s
}

View File

@@ -4,50 +4,25 @@
package util
import (
"context"
"net"
"net/http"
"net/url"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
// to route requests through the configured proxy server.
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
var transport *http.Transport
// Attempt to parse the proxy URL from the configuration.
proxyURL, errParse := url.Parse(cfg.ProxyURL)
if errParse == nil {
// Handle different proxy schemes.
if proxyURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication.
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return httpClient
}
// Set up a custom transport using the SOCKS5 dialer.
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if cfg == nil || httpClient == nil {
return httpClient
}
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
}
// If a new transport was created, apply it to the HTTP client.
if transport != nil {
httpClient.Transport = transport
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
@@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
w.lastAuthHashes = make(map[string]string)
w.lastAuthContents = make(map[string]*coreauth.Auth)
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
} else if resolvedAuthDir != "" {
@@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
if errParse := json.Unmarshal(data, &auth); errParse == nil {
w.lastAuthContents[normalizedPath] = &auth
}
ctx := &synthesizer.SynthesisContext{
Config: cfg,
AuthDir: resolvedAuthDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
w.fileAuthsByPath[normalizedPath] = pathAuths
}
}
}
}
return nil
@@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
}
w.clientsMutex.Lock()
cfg := w.config
if cfg == nil {
if w.config == nil {
log.Error("config is nil, cannot add or update client")
w.clientsMutex.Unlock()
return
}
if w.fileAuthsByPath == nil {
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
}
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
w.clientsMutex.Unlock()
@@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) {
}
w.lastAuthContents[normalized] = &newAuth
w.clientsMutex.Unlock() // Unlock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after add/update")
w.triggerServerUpdate(cfg)
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
for id, a := range w.fileAuthsByPath[normalized] {
oldByID[id] = a
}
// Build synthesized auth entries for this single file only.
sctx := &synthesizer.SynthesisContext{
Config: w.config,
AuthDir: w.authDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
newByID := authSliceToMap(generated)
if len(newByID) > 0 {
w.fileAuthsByPath[normalized] = newByID
} else {
delete(w.fileAuthsByPath, normalized)
}
updates := w.computePerPathUpdatesLocked(oldByID, newByID)
w.clientsMutex.Unlock()
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) removeClient(path string) {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
for id, a := range w.fileAuthsByPath[normalized] {
oldByID[id] = a
}
delete(w.lastAuthHashes, normalized)
delete(w.lastAuthContents, normalized)
delete(w.fileAuthsByPath, normalized)
w.clientsMutex.Unlock() // Release the lock before the callback
updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{})
w.clientsMutex.Unlock()
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after removal")
w.triggerServerUpdate(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))
for id, newAuth := range newByID {
existing, ok := w.currentAuths[id]
if !ok {
w.currentAuths[id] = newAuth.Clone()
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()})
continue
}
if !authEqual(existing, newAuth) {
w.currentAuths[id] = newAuth.Clone()
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()})
}
}
for id := range oldByID {
if _, stillExists := newByID[id]; stillExists {
continue
}
delete(w.currentAuths, id)
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
}
return updates
}
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
byID := make(map[string]*coreauth.Auth, len(auths))
for _, a := range auths {
if a == nil || strings.TrimSpace(a.ID) == "" {
continue
}
byID[a.ID] = a
}
return byID
}
func (w *Watcher) loadFileClients(cfg *config.Config) int {

View File

@@ -14,6 +14,8 @@ import (
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
var snapshotCoreAuthsFunc = snapshotCoreAuths
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
@@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
}
func (w *Watcher) refreshAuthState(force bool) {
auths := w.SnapshotCoreAuths()
w.clientsMutex.RLock()
cfg := w.config
authDir := w.authDir
w.clientsMutex.RUnlock()
auths := snapshotCoreAuthsFunc(cfg, authDir)
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {

View File

@@ -10,6 +10,7 @@ import (
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
@@ -36,9 +37,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
return out, nil
}
now := ctx.Now
cfg := ctx.Config
for _, e := range entries {
if e.IsDir() {
continue
@@ -52,99 +50,130 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
if errRead != nil || len(data) == 0 {
continue
}
var metadata map[string]any
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
auths := synthesizeFileAuths(ctx, full, data)
if len(auths) == 0 {
continue
}
t, _ := metadata["type"].(string)
if t == "" {
continue
}
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
}
label := provider
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// Use relative path under authDir as ID to stay consistent with the file-based token store
id := full
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
id = rel
}
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
proxyURL := ""
if p, ok := metadata["proxy_url"].(string); ok {
proxyURL = p
}
prefix := ""
if rawPrefix, ok := metadata["prefix"].(string); ok {
trimmed := strings.TrimSpace(rawPrefix)
trimmed = strings.Trim(trimmed, "/")
if trimmed != "" && !strings.Contains(trimmed, "/") {
prefix = trimmed
}
}
disabled, _ := metadata["disabled"].(bool)
status := coreauth.StatusActive
if disabled {
status = coreauth.StatusDisabled
}
// Read per-account excluded models from the OAuth JSON file
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
a := &coreauth.Auth{
ID: id,
Provider: provider,
Label: label,
Prefix: prefix,
Status: status,
Disabled: disabled,
Attributes: map[string]string{
"source": full,
"path": full,
},
ProxyURL: proxyURL,
Metadata: metadata,
CreatedAt: now,
UpdatedAt: now,
}
// Read priority from auth file
if rawPriority, ok := metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
a.Attributes["priority"] = strconv.Itoa(int(v))
case string:
priority := strings.TrimSpace(v)
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
a.Attributes["priority"] = priority
}
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
}
out = append(out, a)
out = append(out, virtuals...)
continue
}
}
out = append(out, a)
out = append(out, auths...)
}
return out, nil
}
// SynthesizeAuthFile generates Auth entries for one auth JSON file payload.
// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize.
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
return synthesizeFileAuths(ctx, fullPath, data)
}
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
if ctx == nil || len(data) == 0 {
return nil
}
now := ctx.Now
cfg := ctx.Config
var metadata map[string]any
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
return nil
}
t, _ := metadata["type"].(string)
if t == "" {
return nil
}
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
}
label := provider
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// Use relative path under authDir as ID to stay consistent with the file-based token store.
id := fullPath
if strings.TrimSpace(ctx.AuthDir) != "" {
if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
id = rel
}
}
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
proxyURL := ""
if p, ok := metadata["proxy_url"].(string); ok {
proxyURL = p
}
prefix := ""
if rawPrefix, ok := metadata["prefix"].(string); ok {
trimmed := strings.TrimSpace(rawPrefix)
trimmed = strings.Trim(trimmed, "/")
if trimmed != "" && !strings.Contains(trimmed, "/") {
prefix = trimmed
}
}
disabled, _ := metadata["disabled"].(bool)
status := coreauth.StatusActive
if disabled {
status = coreauth.StatusDisabled
}
// Read per-account excluded models from the OAuth JSON file.
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
a := &coreauth.Auth{
ID: id,
Provider: provider,
Label: label,
Prefix: prefix,
Status: status,
Disabled: disabled,
Attributes: map[string]string{
"source": fullPath,
"path": fullPath,
},
ProxyURL: proxyURL,
Metadata: metadata,
CreatedAt: now,
UpdatedAt: now,
}
// Read priority from auth file.
if rawPriority, ok := metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
a.Attributes["priority"] = strconv.Itoa(int(v))
case string:
priority := strings.TrimSpace(v)
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
a.Attributes["priority"] = priority
}
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
// For codex auth files, extract plan_type from the JWT id_token.
if provider == "codex" {
if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
a.Attributes["plan_type"] = pt
}
}
}
}
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
}
out := make([]*coreauth.Auth, 0, 1+len(virtuals))
out = append(out, a)
out = append(out, virtuals...)
return out
}
}
return []*coreauth.Auth{a}
}
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
// It disables the primary auth and creates one virtual auth per project.
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {

View File

@@ -45,6 +45,7 @@ type Watcher struct {
watcher *fsnotify.Watcher
lastAuthHashes map[string]string
lastAuthContents map[string]*coreauth.Auth
fileAuthsByPath map[string]map[string]*coreauth.Auth
lastRemoveTimes map[string]time.Time
lastConfigHash string
authQueue chan<- AuthUpdate
@@ -92,11 +93,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
return nil, errNewWatcher
}
w := &Watcher{
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
lastAuthHashes: make(map[string]string),
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
lastAuthHashes: make(map[string]string),
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
}
w.dispatchCond = sync.NewCond(&w.dispatchMu)
if store := sdkAuth.GetTokenStore(); store != nil {

View File

@@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
w.addOrUpdateClient(authFile)
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected reload callback once, got %d", got)
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no reload callback for auth update, got %d", got)
}
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
normalized := w.normalizeAuthPath(authFile)
@@ -436,8 +436,110 @@ func TestRemoveClientRemovesHash(t *testing.T) {
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected hash to be removed after deletion")
}
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected reload callback once, got %d", got)
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no reload callback for auth removal, got %d", got)
}
}
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
tmpDir := t.TempDir()
authFile := filepath.Join(tmpDir, "sample.json")
if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil {
t.Fatalf("failed to create auth file: %v", err)
}
origSnapshot := snapshotCoreAuthsFunc
var snapshotCalls int32
snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
atomic.AddInt32(&snapshotCalls, 1)
return origSnapshot(cfg, authDir)
}
defer func() { snapshotCoreAuthsFunc = origSnapshot }()
w := &Watcher{
authDir: tmpDir,
lastAuthHashes: make(map[string]string),
lastAuthContents: make(map[string]*coreauth.Auth),
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
}
w.SetConfig(&config.Config{AuthDir: tmpDir})
w.addOrUpdateClient(authFile)
w.removeClient(authFile)
if got := atomic.LoadInt32(&snapshotCalls); got != 0 {
t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got)
}
}
func TestAuthSliceToMap(t *testing.T) {
t.Parallel()
valid1 := &coreauth.Auth{ID: "a"}
valid2 := &coreauth.Auth{ID: "b"}
dupOld := &coreauth.Auth{ID: "dup", Label: "old"}
dupNew := &coreauth.Auth{ID: "dup", Label: "new"}
empty := &coreauth.Auth{ID: " "}
tests := []struct {
name string
in []*coreauth.Auth
want map[string]*coreauth.Auth
}{
{
name: "nil input",
in: nil,
want: map[string]*coreauth.Auth{},
},
{
name: "empty input",
in: []*coreauth.Auth{},
want: map[string]*coreauth.Auth{},
},
{
name: "filters invalid auths",
in: []*coreauth.Auth{nil, empty},
want: map[string]*coreauth.Auth{},
},
{
name: "keeps valid auths",
in: []*coreauth.Auth{valid1, nil, valid2},
want: map[string]*coreauth.Auth{"a": valid1, "b": valid2},
},
{
name: "last duplicate wins",
in: []*coreauth.Auth{dupOld, dupNew},
want: map[string]*coreauth.Auth{"dup": dupNew},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := authSliceToMap(tc.in)
if len(tc.want) == 0 {
if got == nil {
t.Fatal("expected empty map, got nil")
}
if len(got) != 0 {
t.Fatalf("expected empty map, got %#v", got)
}
return
}
if len(got) != len(tc.want) {
t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
}
for id, wantAuth := range tc.want {
gotAuth, ok := got[id]
if !ok {
t.Fatalf("missing id %q in result map", id)
}
if !authEqual(gotAuth, wantAuth) {
t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
}
}
})
}
}
@@ -695,8 +797,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) {
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected reload callback once, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected no reload callback for auth removal, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected hash entry to be removed")
@@ -893,8 +995,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads)
}
}
@@ -990,8 +1092,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads)
}
}
@@ -1045,8 +1147,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected known remove to avoid global reload, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected known auth hash to be deleted")

View File

@@ -0,0 +1,151 @@
package claude
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestClaudeMessagesWithGitLabDuoAnthropicGateway(t *testing.T) {
gin.SetMode(gin.TestMode)
var gotPath, gotAuthHeader, gotRealmHeader string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
}))
defer upstream.Close()
manager, _ := registerGitLabDuoAnthropicAuth(t, upstream.URL)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewClaudeCodeAPIHandler(base)
router := gin.New()
router.POST("/v1/messages", h.ClaudeMessages)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
"model":"claude-sonnet-4-5",
"max_tokens":128,
"messages":[{"role":"user","content":"list files"}],
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}]
}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Anthropic-Version", "2023-06-01")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
}
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("x-gitlab-realm = %q, want saas", gotRealmHeader)
}
if !strings.Contains(resp.Body.String(), `"tool_use"`) {
t.Fatalf("expected tool_use response, got %s", resp.Body.String())
}
if !strings.Contains(resp.Body.String(), `"Bash"`) {
t.Fatalf("expected Bash tool in response, got %s", resp.Body.String())
}
}
func TestClaudeMessagesStreamWithGitLabDuoAnthropicGateway(t *testing.T) {
gin.SetMode(gin.TestMode)
var gotPath string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: message_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
_, _ = w.Write([]byte("event: content_block_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
_, _ = w.Write([]byte("event: content_block_delta\n"))
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from duo\"}}\n\n"))
_, _ = w.Write([]byte("event: message_delta\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
_, _ = w.Write([]byte("event: message_stop\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer upstream.Close()
manager, _ := registerGitLabDuoAnthropicAuth(t, upstream.URL)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewClaudeCodeAPIHandler(base)
router := gin.New()
router.POST("/v1/messages", h.ClaudeMessages)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
"model":"claude-sonnet-4-5",
"stream":true,
"max_tokens":64,
"messages":[{"role":"user","content":"hello"}]
}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Anthropic-Version", "2023-06-01")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
}
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if got := resp.Header().Get("Content-Type"); got != "text/event-stream" {
t.Fatalf("content-type = %q, want text/event-stream", got)
}
if !strings.Contains(resp.Body.String(), "event: content_block_delta") {
t.Fatalf("expected streamed claude event, got %s", resp.Body.String())
}
if !strings.Contains(resp.Body.String(), "hello from duo") {
t.Fatalf("expected streamed text, got %s", resp.Body.String())
}
}
func registerGitLabDuoAnthropicAuth(t *testing.T, upstreamURL string) (*coreauth.Manager, string) {
t.Helper()
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(runtimeexecutor.NewGitLabExecutor(&internalconfig.Config{}))
auth := &coreauth.Auth{
ID: "gitlab-duo-claude-handler-test",
Provider: "gitlab",
Status: coreauth.StatusActive,
Metadata: map[string]any{
"duo_gateway_base_url": upstreamURL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
registered, err := manager.Register(context.Background(), auth)
if err != nil {
t.Fatalf("register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(registered.ID, registered.Provider, runtimeexecutor.GitLabModelsFromAuth(registered))
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(registered.ID)
})
return manager, registered.ID
}

View File

@@ -0,0 +1,143 @@
package openai
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestOpenAIChatCompletionsWithGitLabDuoOpenAIGateway(t *testing.T) {
gin.SetMode(gin.TestMode)
var gotPath, gotAuthHeader, gotRealmHeader string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from duo openai\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"status\":\"completed\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from duo openai\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
}))
defer upstream.Close()
manager := registerGitLabDuoOpenAIAuth(t, upstream.URL)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIAPIHandler(base)
router := gin.New()
router.POST("/v1/chat/completions", h.ChatCompletions)
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{
"model":"gpt-5-codex",
"messages":[{"role":"user","content":"hello"}]
}`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("x-gitlab-realm = %q, want saas", gotRealmHeader)
}
if !strings.Contains(resp.Body.String(), `"content":"hello from duo openai"`) {
t.Fatalf("expected translated chat completion, got %s", resp.Body.String())
}
}
func TestOpenAIResponsesStreamWithGitLabDuoOpenAIGateway(t *testing.T) {
gin.SetMode(gin.TestMode)
var gotPath, gotAuthHeader string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"streamed duo output\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"status\":\"completed\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"streamed duo output\"}]}],\"usage\":{\"input_tokens\":10,\"output_tokens\":3,\"total_tokens\":13}}}\n\n"))
}))
defer upstream.Close()
manager := registerGitLabDuoOpenAIAuth(t, upstream.URL)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.POST("/v1/responses", h.Responses)
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{
"model":"gpt-5-codex",
"stream":true,
"input":"hello"
}`))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if got := resp.Header().Get("Content-Type"); got != "text/event-stream" {
t.Fatalf("content-type = %q, want text/event-stream", got)
}
if !strings.Contains(resp.Body.String(), `"type":"response.output_text.delta"`) {
t.Fatalf("expected streamed responses delta, got %s", resp.Body.String())
}
if !strings.Contains(resp.Body.String(), `"type":"response.completed"`) {
t.Fatalf("expected streamed responses completion, got %s", resp.Body.String())
}
}
func registerGitLabDuoOpenAIAuth(t *testing.T, upstreamURL string) *coreauth.Manager {
t.Helper()
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(runtimeexecutor.NewGitLabExecutor(&internalconfig.Config{}))
auth := &coreauth.Auth{
ID: "gitlab-duo-openai-handler-test",
Provider: "gitlab",
Status: coreauth.StatusActive,
Metadata: map[string]any{
"duo_gateway_base_url": upstreamURL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "openai",
"model_name": "gpt-5-codex",
},
}
registered, err := manager.Register(context.Background(), auth)
if err != nil {
t.Fatalf("register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(registered.ID, registered.Provider, runtimeexecutor.GitLabModelsFromAuth(registered))
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(registered.ID)
})
return manager
}

View File

@@ -14,7 +14,11 @@ import (
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -26,11 +30,12 @@ const (
wsRequestTypeAppend = "response.append"
wsEventTypeError = "error"
wsEventTypeCompleted = "response.completed"
wsEventTypeDone = "response.done"
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
wsBodyLogMaxSize = 64 * 1024
wsBodyLogTruncated = "\n[websocket log truncated]\n"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -101,11 +106,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
var requestJSON []byte
@@ -140,6 +151,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
requestJSON = updated
}
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
updatedLastRequest = updated
}
lastRequest = updatedLastRequest
lastResponseOutput = []byte("[]")
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
wsTerminateErr = errWrite
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
return
}
continue
}
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -340,6 +367,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
return false
}
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
if h == nil || h.AuthManager == nil {
return false
}
resolvedModelName := modelName
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
if initialSuffix.HasSuffix {
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
} else {
resolvedModelName = resolvedBase
}
} else {
resolvedModelName = util.ResolveAutoModel(modelName)
}
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
providers := util.GetProviderName(baseModel)
if len(providers) == 0 && baseModel != resolvedModelName {
providers = util.GetProviderName(resolvedModelName)
}
if len(providers) == 0 {
return false
}
providerSet := make(map[string]struct{}, len(providers))
for i := 0; i < len(providers); i++ {
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
if providerKey == "" {
continue
}
providerSet[providerKey] = struct{}{}
}
if len(providerSet) == 0 {
return false
}
modelKey := baseModel
if modelKey == "" {
modelKey = strings.TrimSpace(resolvedModelName)
}
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
for i := 0; i < len(auths); i++ {
auth := auths[i]
if auth == nil {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
continue
}
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
continue
}
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
}
return false
}
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
if auth == nil {
return false
}
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
return false
}
if modelName != "" && len(auth.ModelStates) > 0 {
state, ok := auth.ModelStates[modelName]
if (!ok || state == nil) && modelName != "" {
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
if baseModel != "" && baseModel != modelName {
state, ok = auth.ModelStates[baseModel]
}
}
if ok && state != nil {
if state.Status == coreauth.StatusDisabled {
return false
}
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
return false
}
return true
}
}
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
return false
}
return true
}
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
return false
}
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
return false
}
generateResult := gjson.GetBytes(rawJSON, "generate")
return generateResult.Exists() && !generateResult.Bool()
}
func writeResponsesWebsocketSyntheticPrewarm(
c *gin.Context,
conn *websocket.Conn,
requestJSON []byte,
wsBodyLog *strings.Builder,
sessionID string,
) error {
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
if errPayloads != nil {
return errPayloads
}
for i := 0; i < len(payloads); i++ {
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
// websocket.TextMessage,
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(payloads[i]),
errWrite,
)
return errWrite
}
}
return nil
}
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
responseID := "resp_prewarm_" + uuid.NewString()
createdAt := time.Now().Unix()
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
var errSet error
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
if errSet != nil {
return nil, errSet
}
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
if errSet != nil {
return nil, errSet
}
if modelName != "" {
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
if errSet != nil {
return nil, errSet
}
}
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
if errSet != nil {
return nil, errSet
}
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
if errSet != nil {
return nil, errSet
}
if modelName != "" {
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
if errSet != nil {
return nil, errSet
}
}
return [][]byte{createdPayload, completedPayload}, nil
}
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
existingRaw = strings.TrimSpace(existingRaw)
appendRaw = strings.TrimSpace(appendRaw)
@@ -469,9 +682,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
for i := range payloads {
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
completed = true
completedOutput = responseCompletedOutputFromPayload(payloads[i])
}
@@ -554,65 +764,134 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
}
body := handlers.BuildErrorResponseBody(status, errText)
payload := map[string]any{
"type": wsEventTypeError,
"status": status,
payload := []byte(`{}`)
var errSet error
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
if errSet != nil {
return nil, errSet
}
payload, errSet = sjson.SetBytes(payload, "status", status)
if errSet != nil {
return nil, errSet
}
if errMsg != nil && errMsg.Addon != nil {
headers := map[string]any{}
headers := []byte(`{}`)
hasHeaders := false
for key, values := range errMsg.Addon {
if len(values) == 0 {
continue
}
headers[key] = values[0]
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
if errSet != nil {
return nil, errSet
}
hasHeaders = true
}
if len(headers) > 0 {
payload["headers"] = headers
}
}
if len(body) > 0 && json.Valid(body) {
var decoded map[string]any
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
if inner, ok := decoded["error"]; ok {
payload["error"] = inner
} else {
payload["error"] = decoded
if hasHeaders {
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
if errSet != nil {
return nil, errSet
}
}
}
if _, ok := payload["error"]; !ok {
payload["error"] = map[string]any{
"type": "server_error",
"message": errText,
if len(body) > 0 && json.Valid(body) {
errorNode := gjson.GetBytes(body, "error")
if errorNode.Exists() {
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
} else {
payload, errSet = sjson.SetRawBytes(payload, "error", body)
}
if errSet != nil {
return nil, errSet
}
}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
if !gjson.GetBytes(payload, "error").Exists() {
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
if errSet != nil {
return nil, errSet
}
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
if errSet != nil {
return nil, errSet
}
}
return data, conn.WriteMessage(websocket.TextMessage, data)
return payload, conn.WriteMessage(websocket.TextMessage, payload)
}
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
if builder == nil {
return
}
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
if !appendWebsocketLogString(builder, "\n") {
return
}
}
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
if !appendWebsocketLogString(builder, "websocket.") {
return
}
if !appendWebsocketLogString(builder, eventType) {
return
}
if !appendWebsocketLogString(builder, "\n") {
return
}
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
appendWebsocketLogString(builder, wsBodyLogTruncated)
return
}
appendWebsocketLogString(builder, "\n")
}
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.WriteString(value)
return true
}
builder.WriteString(value[:remaining])
return false
}
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.Write(value)
return true
}
limit := remaining - reserveForSuffix
if limit < 0 {
limit = 0
}
if limit > len(value) {
limit = len(value)
}
builder.Write(value[:limit])
return false
}
func websocketPayloadEventType(payload []byte) string {

View File

@@ -2,15 +2,57 @@ package openai
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
type websocketCaptureExecutor struct {
streamCalls int
payloads [][]byte
}
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.streamCalls++
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
@@ -224,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) {
}
}
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
var builder strings.Builder
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
appendWebsocketEvent(&builder, "request", payload)
got := builder.String()
if len(got) > wsBodyLogMaxSize {
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
}
if !strings.Contains(got, wsBodyLogTruncated) {
t.Fatalf("expected truncation marker in body log")
}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
@@ -247,3 +316,206 @@ func TestSetWebsocketRequestBody(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
}
}
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
serverErrCh := make(chan error, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
if err != nil {
serverErrCh <- err
return
}
defer func() {
errClose := conn.Close()
if errClose != nil {
serverErrCh <- errClose
}
}()
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Request = r
data := make(chan []byte, 1)
errCh := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
close(data)
close(errCh)
var bodyLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
data,
errCh,
&bodyLog,
"session-1",
)
if err != nil {
serverErrCh <- err
return
}
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
serverErrCh <- errors.New("completed output not captured")
return
}
serverErrCh <- nil
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
errClose := conn.Close()
if errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read websocket message: %v", errReadMessage)
}
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
}
if strings.Contains(string(payload), "response.done") {
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
}
if errServer := <-serverErrCh; errServer != nil {
t.Fatalf("server error: %v", errServer)
}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{
ID: "auth-ws",
Provider: "test-provider",
Status: coreauth.StatusActive,
Attributes: map[string]string{"websockets": "true"},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
t.Fatalf("expected websocket-capable upstream for test-model")
}
}
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketCaptureExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
errClose := conn.Close()
if errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
if errWrite != nil {
t.Fatalf("write prewarm websocket message: %v", errWrite)
}
_, createdPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read prewarm created message: %v", errReadMessage)
}
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
}
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
if prewarmResponseID == "" {
t.Fatalf("prewarm response id is empty")
}
if executor.streamCalls != 0 {
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
}
_, completedPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read prewarm completed message: %v", errReadMessage)
}
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
}
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
}
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
}
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
if errWrite != nil {
t.Fatalf("write follow-up websocket message: %v", errWrite)
}
_, upstreamPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read upstream completed message: %v", errReadMessage)
}
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
}
if executor.streamCalls != 1 {
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
}
if len(executor.payloads) != 1 {
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
}
forwarded := executor.payloads[0]
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
}
if gjson.GetBytes(forwarded, "generate").Exists() {
t.Fatalf("generate leaked upstream: %s", forwarded)
}
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
}
input := gjson.GetBytes(forwarded, "input").Array()
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected forwarded input: %s", forwarded)
}
}

View File

@@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"plan_type": planType,
},
}, nil
}

485
sdk/auth/gitlab.go Normal file
View File

@@ -0,0 +1,485 @@
package auth
import (
"context"
"fmt"
"os"
"strings"
"time"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
const (
gitLabLoginModeMetadataKey = "login_mode"
gitLabLoginModeOAuth = "oauth"
gitLabLoginModePAT = "pat"
gitLabBaseURLMetadataKey = "base_url"
gitLabOAuthClientIDMetadataKey = "oauth_client_id"
gitLabOAuthClientSecretMetadataKey = "oauth_client_secret"
gitLabPersonalAccessTokenMetadataKey = "personal_access_token"
)
var gitLabRefreshLead = 5 * time.Minute
type GitLabAuthenticator struct {
CallbackPort int
}
func NewGitLabAuthenticator() *GitLabAuthenticator {
return &GitLabAuthenticator{CallbackPort: gitlabauth.DefaultCallbackPort}
}
func (a *GitLabAuthenticator) Provider() string {
return "gitlab"
}
func (a *GitLabAuthenticator) RefreshLead() *time.Duration {
return &gitLabRefreshLead
}
func (a *GitLabAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
switch strings.ToLower(strings.TrimSpace(opts.Metadata[gitLabLoginModeMetadataKey])) {
case "", gitLabLoginModeOAuth:
return a.loginOAuth(ctx, cfg, opts)
case gitLabLoginModePAT:
return a.loginPAT(ctx, cfg, opts)
default:
return nil, fmt.Errorf("gitlab auth: unsupported login mode %q", opts.Metadata[gitLabLoginModeMetadataKey])
}
}
func (a *GitLabAuthenticator) loginOAuth(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
client := gitlabauth.NewAuthClient(cfg)
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
clientID, err := a.requireInput(opts, gitLabOAuthClientIDMetadataKey, "Enter GitLab OAuth application client ID: ")
if err != nil {
return nil, err
}
clientSecret, err := a.optionalInput(opts, gitLabOAuthClientSecretMetadataKey, "Enter GitLab OAuth application client secret (press Enter for public PKCE app): ")
if err != nil {
return nil, err
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
redirectURI := gitlabauth.RedirectURL(callbackPort)
pkceCodes, err := gitlabauth.GeneratePKCECodes()
if err != nil {
return nil, err
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("gitlab state generation failed: %w", err)
}
oauthServer := gitlabauth.NewOAuthServer(callbackPort)
if err := oauthServer.Start(); err != nil {
return nil, err
}
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
log.Warnf("gitlab oauth server stop error: %v", stopErr)
}
}()
authURL, err := client.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
if err != nil {
return nil, err
}
if !opts.NoBrowser {
fmt.Println("Opening browser for GitLab Duo authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for GitLab OAuth callback...")
callbackCh := make(chan *gitlabauth.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
go func() {
result, waitErr := oauthServer.WaitForCallback(5 * time.Minute)
if waitErr != nil {
callbackErrCh <- waitErr
return
}
callbackCh <- result
}()
var result *gitlabauth.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
input, promptErr := opts.Prompt("Paste the GitLab callback URL (or press Enter to keep waiting): ")
if promptErr != nil {
return nil, promptErr
}
parsed, parseErr := misc.ParseOAuthCallback(input)
if parseErr != nil {
return nil, parseErr
}
if parsed == nil {
continue
}
result = &gitlabauth.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
}
if result.Error != "" {
return nil, fmt.Errorf("gitlab oauth returned error: %s", result.Error)
}
if result.State != state {
return nil, fmt.Errorf("gitlab auth: state mismatch")
}
tokenResp, err := client.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, result.Code, pkceCodes.CodeVerifier)
if err != nil {
return nil, err
}
accessToken := strings.TrimSpace(tokenResp.AccessToken)
if accessToken == "" {
return nil, fmt.Errorf("gitlab auth: missing access token")
}
user, err := client.GetCurrentUser(ctx, baseURL, accessToken)
if err != nil {
return nil, err
}
direct, err := client.FetchDirectAccess(ctx, baseURL, accessToken)
if err != nil {
return nil, err
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata[gitLabOAuthClientIDMetadataKey] = clientID
if strings.TrimSpace(clientSecret) != "" {
metadata[gitLabOAuthClientSecretMetadataKey] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
fmt.Println("GitLab Duo authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: identifier,
Metadata: metadata,
}, nil
}
func (a *GitLabAuthenticator) loginPAT(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
client := gitlabauth.NewAuthClient(cfg)
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
token, err := a.requireInput(opts, gitLabPersonalAccessTokenMetadataKey, "Enter GitLab personal access token: ")
if err != nil {
return nil, err
}
user, err := client.GetCurrentUser(ctx, baseURL, token)
if err != nil {
return nil, err
}
_, err = client.GetPersonalAccessTokenSelf(ctx, baseURL, token)
if err != nil {
return nil, err
}
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
if err != nil {
return nil, err
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
metadata["auth_kind"] = "personal_access_token"
metadata[gitLabPersonalAccessTokenMetadataKey] = strings.TrimSpace(token)
metadata["token_preview"] = maskGitLabToken(token)
metadata["username"] = strings.TrimSpace(user.Username)
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
fmt.Println("GitLab Duo PAT authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: identifier + " (PAT)",
Metadata: metadata,
}, nil
}
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
metadata := map[string]any{
"type": "gitlab",
"auth_method": strings.TrimSpace(mode),
gitLabBaseURLMetadataKey: gitlabauth.NormalizeBaseURL(baseURL),
"last_refresh": time.Now().UTC().Format(time.RFC3339),
"refresh_interval_seconds": 240,
}
if tokenResp != nil {
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
metadata["refresh_token"] = refreshToken
}
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
metadata["token_type"] = tokenType
}
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
metadata["scope"] = scope
}
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
}
}
mergeGitLabDirectAccessMetadata(metadata, direct)
return metadata
}
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
if metadata == nil || direct == nil {
return
}
if base := strings.TrimSpace(direct.BaseURL); base != "" {
metadata["duo_gateway_base_url"] = base
}
if token := strings.TrimSpace(direct.Token); token != "" {
metadata["duo_gateway_token"] = token
}
if direct.ExpiresAt > 0 {
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
now := time.Now().UTC()
if ttl := expiry.Sub(now); ttl > 0 {
interval := int(ttl.Seconds()) / 2
switch {
case interval < 60:
interval = 60
case interval > 240:
interval = 240
}
metadata["refresh_interval_seconds"] = interval
}
}
if len(direct.Headers) > 0 {
headers := make(map[string]string, len(direct.Headers))
for key, value := range direct.Headers {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
headers[key] = value
}
if len(headers) > 0 {
metadata["duo_gateway_headers"] = headers
}
}
if direct.ModelDetails != nil {
modelDetails := map[string]any{}
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
modelDetails["model_provider"] = provider
metadata["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
modelDetails["model_name"] = model
metadata["model_name"] = model
}
if len(modelDetails) > 0 {
metadata["model_details"] = modelDetails
}
}
}
func (a *GitLabAuthenticator) resolveString(opts *LoginOptions, key, fallback string) string {
if opts != nil && opts.Metadata != nil {
if value := strings.TrimSpace(opts.Metadata[key]); value != "" {
return value
}
}
for _, envKey := range gitLabEnvKeys(key) {
if raw, ok := os.LookupEnv(envKey); ok {
if trimmed := strings.TrimSpace(raw); trimmed != "" {
return trimmed
}
}
}
if strings.TrimSpace(fallback) != "" {
return fallback
}
return ""
}
func (a *GitLabAuthenticator) requireInput(opts *LoginOptions, key, prompt string) (string, error) {
if value := a.resolveString(opts, key, ""); value != "" {
return value, nil
}
if opts != nil && opts.Prompt != nil {
value, err := opts.Prompt(prompt)
if err != nil {
return "", err
}
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed, nil
}
}
return "", fmt.Errorf("gitlab auth: missing required %s", key)
}
func (a *GitLabAuthenticator) optionalInput(opts *LoginOptions, key, prompt string) (string, error) {
if value := a.resolveString(opts, key, ""); value != "" {
return value, nil
}
if opts != nil && opts.Prompt != nil {
value, err := opts.Prompt(prompt)
if err != nil {
return "", err
}
return strings.TrimSpace(value), nil
}
return "", nil
}
func primaryGitLabEmail(user *gitlabauth.User) string {
if user == nil {
return ""
}
if value := strings.TrimSpace(user.Email); value != "" {
return value
}
return strings.TrimSpace(user.PublicEmail)
}
func gitLabAccountIdentifier(user *gitlabauth.User) string {
if user == nil {
return "user"
}
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return "user"
}
func sanitizeGitLabFileName(value string) string {
value = strings.TrimSpace(strings.ToLower(value))
if value == "" {
return "user"
}
var builder strings.Builder
lastDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
builder.WriteRune(r)
lastDash = false
case r >= '0' && r <= '9':
builder.WriteRune(r)
lastDash = false
case r == '-' || r == '_' || r == '.':
builder.WriteRune(r)
lastDash = false
default:
if !lastDash {
builder.WriteRune('-')
lastDash = true
}
}
}
result := strings.Trim(builder.String(), "-")
if result == "" {
return "user"
}
return result
}
func maskGitLabToken(token string) string {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return ""
}
if len(trimmed) <= 8 {
return trimmed
}
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
}
func gitLabEnvKeys(key string) []string {
switch strings.TrimSpace(key) {
case gitLabBaseURLMetadataKey:
return []string{"GITLAB_BASE_URL"}
case gitLabOAuthClientIDMetadataKey:
return []string{"GITLAB_OAUTH_CLIENT_ID"}
case gitLabOAuthClientSecretMetadataKey:
return []string{"GITLAB_OAUTH_CLIENT_SECRET"}
case gitLabPersonalAccessTokenMetadataKey:
return []string{"GITLAB_PERSONAL_ACCESS_TOKEN"}
default:
return nil
}
}

66
sdk/auth/gitlab_test.go Normal file
View File

@@ -0,0 +1,66 @@
package auth
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestGitLabAuthenticatorLoginPAT(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v4/user":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 42,
"username": "duo-user",
"email": "duo@example.com",
"name": "Duo User",
})
case "/api/v4/personal_access_tokens/self":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 5,
"name": "CLIProxyAPI",
"scopes": []string{"api"},
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1710003600,
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
authenticator := NewGitLabAuthenticator()
record, err := authenticator.Login(context.Background(), &config.Config{}, &LoginOptions{
Metadata: map[string]string{
"login_mode": "pat",
"base_url": srv.URL,
"personal_access_token": "glpat-test-token",
},
})
if err != nil {
t.Fatalf("Login() error = %v", err)
}
if record.Provider != "gitlab" {
t.Fatalf("expected gitlab provider, got %q", record.Provider)
}
if got := record.Metadata["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("expected discovered model, got %#v", got)
}
if got := record.Metadata["auth_kind"]; got != "personal_access_token" {
t.Fatalf("expected personal_access_token auth kind, got %#v", got)
}
}

View File

@@ -17,6 +17,7 @@ func init() {
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {

View File

@@ -134,6 +134,7 @@ type Manager struct {
hook Hook
mu sync.RWMutex
auths map[string]*Auth
scheduler *authScheduler
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
@@ -149,6 +150,9 @@ type Manager struct {
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
apiKeyModelAlias atomic.Value
// modelPoolOffsets tracks per-auth alias pool rotation state.
modelPoolOffsets map[string]int
// runtimeConfig stores the latest application config for request-time decisions.
// It is initialized in NewManager; never Load() before first Store().
runtimeConfig atomic.Value
@@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
hook: hook,
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{})
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
manager.scheduler = newAuthScheduler(selector)
return manager
}
func isBuiltInSelector(selector Selector) bool {
switch selector.(type) {
case *RoundRobinSelector, *FillFirstSelector:
return true
default:
return false
}
}
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
if m == nil || m.scheduler == nil {
return
}
m.scheduler.rebuild(auths)
}
func (m *Manager) syncScheduler() {
if m == nil || m.scheduler == nil {
return
}
m.syncSchedulerFromSnapshot(m.snapshotAuths())
}
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
// supportedModelSet is rebuilt from the current global model registry state.
// This must be called after models have been registered for a newly added auth,
// because the initial scheduler.upsertAuth during Register/Update runs before
// registerModelsForAuth and therefore snapshots an empty model set.
func (m *Manager) RefreshSchedulerEntry(authID string) {
if m == nil || m.scheduler == nil || authID == "" {
return
}
m.mu.RLock()
auth, ok := m.auths[authID]
if !ok || auth == nil {
m.mu.RUnlock()
return
}
snapshot := auth.Clone()
m.mu.RUnlock()
m.scheduler.upsertAuth(snapshot)
}
func (m *Manager) SetSelector(selector Selector) {
if m == nil {
return
@@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) {
m.mu.Lock()
m.selector = selector
m.mu.Unlock()
if m.scheduler != nil {
m.scheduler.setSelector(selector)
m.syncScheduler()
}
}
// SetStore swaps the underlying persistence store.
@@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin
if resolved == "" {
return ""
}
// Preserve thinking suffix from the client's requested model unless config already has one.
requestResult := thinking.ParseSuffix(requestedModel)
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
return preserveRequestedModelSuffix(requestedModel, resolved)
}
func isAPIKeyAuth(auth *Auth) bool {
if auth == nil {
return false
}
kind, _ := auth.AccountInfo()
return strings.EqualFold(strings.TrimSpace(kind), "api_key")
}
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
if !isAPIKeyAuth(auth) {
return false
}
if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
return true
}
if auth.Attributes == nil {
return false
}
return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
}
func openAICompatProviderKey(auth *Auth) string {
if auth == nil {
return ""
}
if auth.Attributes != nil {
if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" {
return strings.ToLower(providerKey)
}
if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" {
return strings.ToLower(compatName)
}
}
return strings.ToLower(strings.TrimSpace(auth.Provider))
}
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
if base == "" {
base = strings.TrimSpace(requestedModel)
}
return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base)
}
func (m *Manager) nextModelPoolOffset(key string, size int) int {
if m == nil || size <= 1 {
return 0
}
key = strings.TrimSpace(key)
if key == "" {
return 0
}
m.mu.Lock()
defer m.mu.Unlock()
if m.modelPoolOffsets == nil {
m.modelPoolOffsets = make(map[string]int)
}
offset := m.modelPoolOffsets[key]
if offset >= 2_147_483_640 {
offset = 0
}
m.modelPoolOffsets[key] = offset + 1
if size <= 0 {
return 0
}
return offset % size
}
func rotateStrings(values []string, offset int) []string {
if len(values) <= 1 {
return values
}
if offset <= 0 {
out := make([]string, len(values))
copy(out, values)
return out
}
offset = offset % len(values)
out := make([]string, 0, len(values))
out = append(out, values[offset:]...)
out = append(out, values[:offset]...)
return out
}
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
return nil
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil {
cfg = &internalconfig.Config{}
}
providerKey := ""
compatName := ""
if auth.Attributes != nil {
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
}
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
if entry == nil {
return nil
}
return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
}
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
return m.prepareExecutionModels(auth, routeModel)
}
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
requestedModel := rewriteModelForAuth(routeModel, auth)
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
if len(pool) == 1 {
return pool
}
offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool))
return rotateStrings(pool, offset)
}
resolved := m.applyAPIKeyModelAlias(auth, requestedModel)
if strings.TrimSpace(resolved) == "" {
resolved = requestedModel
}
return []string{resolved}
}
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
if ch == nil {
return
}
go func() {
for range ch {
}
}()
}
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
if ch == nil {
return nil, true, nil
}
buffered := make([]cliproxyexecutor.StreamChunk, 0, 1)
for {
var (
chunk cliproxyexecutor.StreamChunk
ok bool
)
if ctx != nil {
select {
case <-ctx.Done():
return nil, false, ctx.Err()
case chunk, ok = <-ch:
}
} else {
chunk, ok = <-ch
}
if !ok {
return buffered, true, nil
}
if chunk.Err != nil {
return nil, false, chunk.Err
}
buffered = append(buffered, chunk)
if len(chunk.Payload) > 0 {
return buffered, false, nil
}
}
}
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
var failed bool
forward := true
emit := func(chunk cliproxyexecutor.StreamChunk) bool {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
return false
}
if ctx == nil {
out <- chunk
return true
}
select {
case <-ctx.Done():
forward = false
return false
case out <- chunk:
return true
}
}
for _, chunk := range buffered {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
for chunk := range remaining {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
if !failed {
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
}
}()
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
execModels := m.prepareExecutionModels(auth, routeModel)
var lastErr error
for idx, execModel := range execModels {
execReq := req
execReq.Model = execModel
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
if errStream != nil {
if errCtx := ctx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(ctx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
if bootstrapErr != nil {
if errCtx := ctx.Err(); errCtx != nil {
discardStreamChunks(streamResult.Chunks)
return nil, errCtx
}
if isRequestInvalidError(bootstrapErr) {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
return nil, bootstrapErr
}
if idx < len(execModels)-1 {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
lastErr = bootstrapErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
if closed && len(buffered) == 0 {
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
m.MarkResult(ctx, result)
if idx < len(execModels)-1 {
lastErr = emptyErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
remaining := streamResult.Chunks
if closed {
closedCh := make(chan cliproxyexecutor.StreamChunk)
close(closedCh)
remaining = closedCh
}
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
}
if lastErr == nil {
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
}
return nil, lastErr
}
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
@@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
auth.ID = uuid.NewString()
}
auth.EnsureIndex()
authClone := auth.Clone()
m.mu.Lock()
m.auths[auth.ID] = auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil
@@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
}
}
auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone()
authClone := auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil
@@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
// Load resets manager state from the backing store.
func (m *Manager) Load(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.store == nil {
m.mu.Unlock()
return nil
}
items, err := m.store.List(ctx)
if err != nil {
m.mu.Unlock()
return err
}
m.auths = make(map[string]*Auth, len(items))
@@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error {
cfg = &internalconfig.Config{}
}
m.rebuildAPIKeyModelAliasLocked(cfg)
m.mu.Unlock()
m.syncScheduler()
return nil
}
@@ -634,32 +1005,42 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.MarkResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = errExec
lastErr = authErr
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
@@ -696,32 +1077,42 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.hook.OnResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.hook.OnResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = errExec
lastErr = authErr
continue
}
m.hook.OnResult(execCtx, result)
return resp, nil
}
}
@@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
return &cliproxyexecutor.StreamResult{
Headers: streamResult.Headers,
Chunks: out,
}, nil
return streamResult, nil
}
}
@@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
suspendReason := ""
clearModelQuota := false
setModelQuota := false
var authSnapshot *Auth
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
@@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
}
_ = m.persist(ctx, auth)
authSnapshot = auth.Clone()
}
m.mu.Unlock()
if m.scheduler != nil && authSnapshot != nil {
m.scheduler.upsertAuth(authSnapshot)
}
if clearModelQuota && result.Model != "" {
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
@@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int {
}
// isRequestInvalidError returns true if the error represents a client request
// error that should not be retried. Specifically, it checks for 400 Bad Request
// with "invalid_request_error" in the message, indicating the request itself is
// malformed and switching to a different auth will not help.
// error that should not be retried. Specifically, it treats 400 responses with
// "invalid_request_error" and all 422 responses as request-shape failures,
// where switching auths or pooled upstream models will not help.
func isRequestInvalidError(err error) bool {
if err == nil {
return false
}
status := statusCodeFromError(err)
if status != http.StatusBadRequest {
switch status {
case http.StatusBadRequest:
return strings.Contains(err.Error(), "invalid_request_error")
case http.StatusUnprocessableEntity:
return true
default:
return false
}
return strings.Contains(err.Error(), "invalid_request_error")
}
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
@@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
}
}
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
func (m *Manager) useSchedulerFastPath() bool {
if m == nil || m.scheduler == nil {
return false
}
return isBuiltInSelector(m.selector)
}
func shouldRetrySchedulerPick(err error) bool {
if err == nil {
return false
}
var cooldownErr *modelCooldownError
if errors.As(err, &cooldownErr) {
return true
}
var authErr *Error
if !errors.As(err, &authErr) || authErr == nil {
return false
}
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
}
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
@@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
return authCopy, executor, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
if !m.useSchedulerFastPath() {
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
executor, okExecutor := m.Executor(provider)
if !okExecutor {
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
}
if errPick != nil {
return nil, nil, errPick
}
if selected == nil {
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, nil
}
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers))
@@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
return authCopy, executor, providerKey, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
if !m.useSchedulerFastPath() {
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
}
eligibleProviders := make([]string, 0, len(providers))
seenProviders := make(map[string]struct{}, len(providers))
for _, provider := range providers {
providerKey := strings.TrimSpace(strings.ToLower(provider))
if providerKey == "" {
continue
}
if _, seen := seenProviders[providerKey]; seen {
continue
}
if _, okExecutor := m.Executor(providerKey); !okExecutor {
continue
}
seenProviders[providerKey] = struct{}{}
eligibleProviders = append(eligibleProviders, providerKey)
}
if len(eligibleProviders) == 0 {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
}
if errPick != nil {
return nil, nil, "", errPick
}
if selected == nil {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
executor, okExecutor := m.Executor(providerKey)
if !okExecutor {
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, providerKey, nil
}
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil {
return nil
@@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
current.LastError = &Error{Message: err.Error()}
m.auths[id] = current
if m.scheduler != nil {
m.scheduler.upsertAuth(current.Clone())
}
}
m.mu.Unlock()
return

View File

@@ -0,0 +1,163 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerProviderTestExecutor struct {
provider string
}
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
ctx := context.Background()
testCases := []struct {
name string
prime func(*Manager, *Auth) error
}{
{
name: "register",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
return errRegister
},
},
{
name: "update",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
if errRegister != nil {
return errRegister
}
updated := auth.Clone()
updated.Metadata = map[string]any{"updated": true}
_, errUpdate := manager.Update(ctx, updated)
return errUpdate
},
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
manager := NewManager(nil, &RoundRobinSelector{}, nil)
auth := &Auth{
ID: "refresh-entry-" + testCase.name,
Provider: "gemini",
}
if errPrime := testCase.prime(manager, auth); errPrime != nil {
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
}
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
var authErr *Error
if !errors.As(errPick, &authErr) || authErr == nil {
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
}
if authErr.Code != "auth_not_found" {
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
}
if got != nil {
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
}
manager.RefreshSchedulerEntry(auth.ID)
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() after refresh error = %v", errPick)
}
if got == nil || got.ID != auth.ID {
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
}
})
}
}
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
ctx := context.Background()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
oldAuth := &Auth{
ID: "cooldown-stale-old",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
t.Fatalf("register old auth: %v", errRegister)
}
manager.MarkResult(ctx, Result{
AuthID: oldAuth.ID,
Provider: "gemini",
Model: "scheduler-cooldown-rebuild-model",
Success: false,
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
})
newAuth := &Auth{
ID: "cooldown-stale-new",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
t.Fatalf("register new auth: %v", errRegister)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
t.Cleanup(func() {
reg.UnregisterClient(newAuth.ID)
})
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
var cooldownErr *modelCooldownError
if !errors.As(errPick, &cooldownErr) {
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
}
if got != nil {
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
}
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if executor == nil {
t.Fatal("pickNext() executor = nil")
}
if got == nil || got.ID != newAuth.ID {
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
}
}

View File

@@ -80,54 +80,98 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string
return upstreamModel
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
return thinking.SuffixResult{}, nil
}
if len(models) == 0 {
return ""
}
requestResult := thinking.ParseSuffix(requestedModel)
base := requestResult.ModelName
if base == "" {
base = requestedModel
}
candidates := []string{base}
if base != requestedModel {
candidates = append(candidates, requestedModel)
}
return requestResult, candidates
}
preserveSuffix := func(resolved string) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
}
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
if len(models) == 0 {
return nil
}
requestResult, candidates := modelAliasLookupCandidates(requestedModel)
if len(candidates) == 0 {
return nil
}
out := make([]string, 0)
seen := make(map[string]struct{})
for i := range models {
name := strings.TrimSpace(models[i].GetName())
alias := strings.TrimSpace(models[i].GetAlias())
for _, candidate := range candidates {
if candidate == "" {
if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) {
continue
}
if alias != "" && strings.EqualFold(alias, candidate) {
if name != "" {
return preserveSuffix(name)
}
return preserveSuffix(candidate)
resolved := candidate
if name != "" {
resolved = name
}
if name != "" && strings.EqualFold(name, candidate) {
return preserveSuffix(name)
resolved = preserveResolvedModelSuffix(resolved, requestResult)
key := strings.ToLower(strings.TrimSpace(resolved))
if key == "" {
break
}
if _, exists := seen[key]; exists {
break
}
seen[key] = struct{}{}
out = append(out, resolved)
break
}
}
if len(out) > 0 {
return out
}
for i := range models {
name := strings.TrimSpace(models[i].GetName())
for _, candidate := range candidates {
if candidate == "" || name == "" || !strings.EqualFold(name, candidate) {
continue
}
return []string{preserveResolvedModelSuffix(name, requestResult)}
}
}
return nil
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models)
if len(resolved) > 0 {
return resolved[0]
}
return ""
}

View File

@@ -0,0 +1,419 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type openAICompatPoolExecutor struct {
id string
mu sync.Mutex
executeModels []string
countModels []string
streamModels []string
executeErrors map[string]error
countErrors map[string]error
streamFirstErrors map[string]error
streamPayloads map[string][]cliproxyexecutor.StreamChunk
}
func (e *openAICompatPoolExecutor) Identifier() string { return e.id }
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
err := e.executeErrors[req.Model]
e.mu.Unlock()
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.streamModels = append(e.streamModels, req.Model)
err := e.streamFirstErrors[req.Model]
payloadChunks, hasCustomChunks := e.streamPayloads[req.Model]
chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...)
e.mu.Unlock()
ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks)))
if err != nil {
ch <- cliproxyexecutor.StreamChunk{Err: err}
close(ch)
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
}
if !hasCustomChunks {
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)}
} else {
for _, chunk := range chunks {
ch <- chunk
}
}
close(ch)
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
}
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.countModels = append(e.countModels, req.Model)
err := e.countErrors[req.Model]
e.mu.Unlock()
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
_ = ctx
_ = auth
_ = req
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.executeModels))
copy(out, e.executeModels)
return out
}
func (e *openAICompatPoolExecutor) CountModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.countModels))
copy(out, e.countModels)
return out
}
func (e *openAICompatPoolExecutor) StreamModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.streamModels))
copy(out, e.streamModels)
return out
}
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
t.Helper()
cfg := &internalconfig.Config{
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
Name: "pool",
Models: models,
}},
}
m := NewManager(nil, nil, nil)
m.SetConfig(cfg)
if executor == nil {
executor = &openAICompatPoolExecutor{id: "pool"}
}
m.RegisterExecutor(executor)
auth := &Auth{
ID: "pool-auth-" + t.Name(),
Provider: "pool",
Status: StatusActive,
Attributes: map[string]string{
"api_key": "test-key",
"compat_name": "pool",
"provider_key": "pool",
},
}
if _, err := m.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
t.Cleanup(func() {
reg.UnregisterClient(auth.ID)
})
return m
}
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
}
got := executor.CountModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("count calls = %v, want only first invalid model", got)
}
}
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
models := []modelAliasEntry{
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
}
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
if len(got) != len(want) {
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
for i := 0; i < 3; i++ {
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute %d: %v", i, err)
}
if len(resp.Payload) == 0 {
t.Fatalf("execute %d returned empty payload", i)
}
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute error = %v, want %v", err, invalidErr)
}
got := executor.ExecuteModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("execute calls = %v, want only first invalid model", got)
}
}
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute: %v", err)
}
if string(resp.Payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
"qwen3.5-plus": {},
},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute stream: %v", err)
}
var payload []byte
for chunk := range streamResult.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
payload = append(payload, chunk.Payload...)
}
if string(payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute stream: %v", err)
}
var payload []byte
for chunk := range streamResult.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
payload = append(payload, chunk.Payload...)
}
if string(payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
}
}
if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
}
got := executor.StreamModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("stream calls = %v, want only first invalid model", got)
}
}
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
for i := 0; i < 2; i++ {
resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute count %d: %v", i, err)
}
if len(resp.Payload) == 0 {
t.Fatalf("execute count %d returned empty payload", i)
}
}
got := executor.CountModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil {
t.Fatal("expected invalid request error")
}
if err != invalidErr {
t.Fatalf("error = %v, want %v", err, invalidErr)
}
if streamResult != nil {
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
}
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("stream calls = %v, want only first upstream model", got)
}
}

View File

@@ -0,0 +1,904 @@
package auth
import (
"context"
"sort"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
type schedulerStrategy int
const (
schedulerStrategyCustom schedulerStrategy = iota
schedulerStrategyRoundRobin
schedulerStrategyFillFirst
)
// scheduledState describes how an auth currently participates in a model shard.
type scheduledState int
const (
scheduledStateReady scheduledState = iota
scheduledStateCooldown
scheduledStateBlocked
scheduledStateDisabled
)
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
type authScheduler struct {
mu sync.Mutex
strategy schedulerStrategy
providers map[string]*providerScheduler
authProviders map[string]string
mixedCursors map[string]int
}
// providerScheduler stores auth metadata and model shards for a single provider.
type providerScheduler struct {
providerKey string
auths map[string]*scheduledAuthMeta
modelShards map[string]*modelScheduler
}
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
type scheduledAuthMeta struct {
auth *Auth
providerKey string
priority int
virtualParent string
websocketEnabled bool
supportedModelSet map[string]struct{}
}
// modelScheduler tracks ready and blocked auths for one provider/model combination.
type modelScheduler struct {
modelKey string
entries map[string]*scheduledAuth
priorityOrder []int
readyByPriority map[int]*readyBucket
blocked cooldownQueue
}
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
type scheduledAuth struct {
meta *scheduledAuthMeta
auth *Auth
state scheduledState
nextRetryAt time.Time
}
// readyBucket keeps the ready views for one priority level.
type readyBucket struct {
all readyView
ws readyView
}
// readyView holds the selection order for flat or grouped round-robin traversal.
type readyView struct {
flat []*scheduledAuth
cursor int
parentOrder []string
parentCursor int
children map[string]*childBucket
}
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
type childBucket struct {
items []*scheduledAuth
cursor int
}
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
type cooldownQueue []*scheduledAuth
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
func newAuthScheduler(selector Selector) *authScheduler {
return &authScheduler{
strategy: selectorStrategy(selector),
providers: make(map[string]*providerScheduler),
authProviders: make(map[string]string),
mixedCursors: make(map[string]int),
}
}
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
func selectorStrategy(selector Selector) schedulerStrategy {
switch selector.(type) {
case *FillFirstSelector:
return schedulerStrategyFillFirst
case nil, *RoundRobinSelector:
return schedulerStrategyRoundRobin
default:
return schedulerStrategyCustom
}
}
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
func (s *authScheduler) setSelector(selector Selector) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.strategy = selectorStrategy(selector)
clear(s.mixedCursors)
}
// rebuild recreates the complete scheduler state from an auth snapshot.
func (s *authScheduler) rebuild(auths []*Auth) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.providers = make(map[string]*providerScheduler)
s.authProviders = make(map[string]string)
s.mixedCursors = make(map[string]int)
now := time.Now()
for _, auth := range auths {
s.upsertAuthLocked(auth, now)
}
}
// upsertAuth incrementally synchronizes one auth into the scheduler.
func (s *authScheduler) upsertAuth(auth *Auth) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.upsertAuthLocked(auth, time.Now())
}
// removeAuth deletes one auth from every scheduler shard that references it.
func (s *authScheduler) removeAuth(authID string) {
if s == nil {
return
}
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.removeAuthLocked(authID)
}
// pickSingle returns the next auth for a single provider/model request using scheduler state.
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
if s == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
providerKey := strings.ToLower(strings.TrimSpace(provider))
modelKey := canonicalModelKey(model)
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
s.mu.Lock()
defer s.mu.Unlock()
providerState := s.providers[providerKey]
if providerState == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
shard := providerState.ensureModelLocked(modelKey, time.Now())
if shard == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
predicate := func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil {
return false
}
if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
return false
}
if len(tried) > 0 {
if _, ok := tried[entry.auth.ID]; ok {
return false
}
}
return true
}
if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
return picked, nil
}
return nil, shard.unavailableErrorLocked(provider, model, predicate)
}
// pickMixed returns the next auth and provider for a mixed-provider request.
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
if s == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
normalized := normalizeProviderKeys(providers)
if len(normalized) == 0 {
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
modelKey := canonicalModelKey(model)
s.mu.Lock()
defer s.mu.Unlock()
if pinnedAuthID != "" {
providerKey := s.authProviders[pinnedAuthID]
if providerKey == "" || !containsProvider(normalized, providerKey) {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
providerState := s.providers[providerKey]
if providerState == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
shard := providerState.ensureModelLocked(modelKey, time.Now())
predicate := func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
return false
}
if len(tried) == 0 {
return true
}
_, ok := tried[pinnedAuthID]
return !ok
}
if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
return picked, providerKey, nil
}
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
}
predicate := triedPredicate(tried)
candidateShards := make([]*modelScheduler, len(normalized))
bestPriority := 0
hasCandidate := false
now := time.Now()
for providerIndex, providerKey := range normalized {
providerState := s.providers[providerKey]
if providerState == nil {
continue
}
shard := providerState.ensureModelLocked(modelKey, now)
candidateShards[providerIndex] = shard
if shard == nil {
continue
}
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
if !okPriority {
continue
}
if !hasCandidate || priorityReady > bestPriority {
bestPriority = priorityReady
hasCandidate = true
}
}
if !hasCandidate {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
if s.strategy == schedulerStrategyFillFirst {
for providerIndex, providerKey := range normalized {
shard := candidateShards[providerIndex]
if shard == nil {
continue
}
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
if picked != nil {
return picked, providerKey, nil
}
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
start := 0
if len(normalized) > 0 {
start = s.mixedCursors[cursorKey] % len(normalized)
}
for offset := 0; offset < len(normalized); offset++ {
providerIndex := (start + offset) % len(normalized)
providerKey := normalized[providerIndex]
shard := candidateShards[providerIndex]
if shard == nil {
continue
}
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
if picked == nil {
continue
}
s.mixedCursors[cursorKey] = providerIndex + 1
return picked, providerKey, nil
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
now := time.Now()
total := 0
cooldownCount := 0
earliest := time.Time{}
for _, providerKey := range providers {
providerState := s.providers[providerKey]
if providerState == nil {
continue
}
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
if shard == nil {
continue
}
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
total += localTotal
cooldownCount += localCooldownCount
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
earliest = localEarliest
}
}
if total == 0 {
return &Error{Code: "auth_not_found", Message: "no auth available"}
}
if cooldownCount == total && !earliest.IsZero() {
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return newModelCooldownError(model, "", resetIn)
}
return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// triedPredicate builds a filter that excludes auths already attempted for the current request.
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
if len(tried) == 0 {
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
}
return func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil {
return false
}
_, ok := tried[entry.auth.ID]
return !ok
}
}
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
func normalizeProviderKeys(providers []string) []string {
seen := make(map[string]struct{}, len(providers))
out := make([]string, 0, len(providers))
for _, provider := range providers {
providerKey := strings.ToLower(strings.TrimSpace(provider))
if providerKey == "" {
continue
}
if _, ok := seen[providerKey]; ok {
continue
}
seen[providerKey] = struct{}{}
out = append(out, providerKey)
}
return out
}
// containsProvider reports whether provider is present in the normalized provider list.
func containsProvider(providers []string, provider string) bool {
for _, candidate := range providers {
if candidate == provider {
return true
}
}
return false
}
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
if auth == nil {
return
}
authID := strings.TrimSpace(auth.ID)
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
if authID == "" || providerKey == "" || auth.Disabled {
s.removeAuthLocked(authID)
return
}
if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
if previousState := s.providers[previousProvider]; previousState != nil {
previousState.removeAuthLocked(authID)
}
}
meta := buildScheduledAuthMeta(auth)
s.authProviders[authID] = providerKey
s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
}
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
func (s *authScheduler) removeAuthLocked(authID string) {
if authID == "" {
return
}
if providerKey := s.authProviders[authID]; providerKey != "" {
if providerState := s.providers[providerKey]; providerState != nil {
providerState.removeAuthLocked(authID)
}
delete(s.authProviders, authID)
}
}
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
if s.providers == nil {
s.providers = make(map[string]*providerScheduler)
}
providerState := s.providers[providerKey]
if providerState == nil {
providerState = &providerScheduler{
providerKey: providerKey,
auths: make(map[string]*scheduledAuthMeta),
modelShards: make(map[string]*modelScheduler),
}
s.providers[providerKey] = providerState
}
return providerState
}
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping.
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
virtualParent := ""
if auth.Attributes != nil {
virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
}
return &scheduledAuthMeta{
auth: auth,
providerKey: providerKey,
priority: authPriority(auth),
virtualParent: virtualParent,
websocketEnabled: authWebsocketsEnabled(auth),
supportedModelSet: supportedModelSetForAuth(auth.ID),
}
}
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
func supportedModelSetForAuth(authID string) map[string]struct{} {
authID = strings.TrimSpace(authID)
if authID == "" {
return nil
}
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
if len(models) == 0 {
return nil
}
set := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
modelKey := canonicalModelKey(model.ID)
if modelKey == "" {
continue
}
set[modelKey] = struct{}{}
}
return set
}
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
if p == nil || meta == nil || meta.auth == nil {
return
}
p.auths[meta.auth.ID] = meta
for modelKey, shard := range p.modelShards {
if shard == nil {
continue
}
if !meta.supportsModel(modelKey) {
shard.removeEntryLocked(meta.auth.ID)
continue
}
shard.upsertEntryLocked(meta, now)
}
}
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
func (p *providerScheduler) removeAuthLocked(authID string) {
if p == nil || authID == "" {
return
}
delete(p.auths, authID)
for _, shard := range p.modelShards {
if shard != nil {
shard.removeEntryLocked(authID)
}
}
}
// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths.
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
if p == nil {
return nil
}
modelKey = canonicalModelKey(modelKey)
if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
shard.promoteExpiredLocked(now)
return shard
}
shard := &modelScheduler{
modelKey: modelKey,
entries: make(map[string]*scheduledAuth),
readyByPriority: make(map[int]*readyBucket),
}
for _, meta := range p.auths {
if meta == nil || !meta.supportsModel(modelKey) {
continue
}
shard.upsertEntryLocked(meta, now)
}
p.modelShards[modelKey] = shard
return shard
}
// supportsModel reports whether the auth metadata currently supports modelKey.
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
modelKey = canonicalModelKey(modelKey)
if modelKey == "" {
return true
}
if len(m.supportedModelSet) == 0 {
return false
}
_, ok := m.supportedModelSet[modelKey]
return ok
}
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
if m == nil || meta == nil || meta.auth == nil {
return
}
entry, ok := m.entries[meta.auth.ID]
if !ok || entry == nil {
entry = &scheduledAuth{}
m.entries[meta.auth.ID] = entry
}
previousState := entry.state
previousNextRetryAt := entry.nextRetryAt
previousPriority := 0
previousParent := ""
previousWebsocketEnabled := false
if entry.meta != nil {
previousPriority = entry.meta.priority
previousParent = entry.meta.virtualParent
previousWebsocketEnabled = entry.meta.websocketEnabled
}
entry.meta = meta
entry.auth = meta.auth
entry.nextRetryAt = time.Time{}
blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
switch {
case !blocked:
entry.state = scheduledStateReady
case reason == blockReasonCooldown:
entry.state = scheduledStateCooldown
entry.nextRetryAt = next
case reason == blockReasonDisabled:
entry.state = scheduledStateDisabled
default:
entry.state = scheduledStateBlocked
entry.nextRetryAt = next
}
if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
return
}
m.rebuildIndexesLocked()
}
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
func (m *modelScheduler) removeEntryLocked(authID string) {
if m == nil || authID == "" {
return
}
if _, ok := m.entries[authID]; !ok {
return
}
delete(m.entries, authID)
m.rebuildIndexesLocked()
}
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
if m == nil || len(m.blocked) == 0 {
return
}
changed := false
for _, entry := range m.blocked {
if entry == nil || entry.auth == nil {
continue
}
if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
continue
}
blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
switch {
case !blocked:
entry.state = scheduledStateReady
entry.nextRetryAt = time.Time{}
case reason == blockReasonCooldown:
entry.state = scheduledStateCooldown
entry.nextRetryAt = next
case reason == blockReasonDisabled:
entry.state = scheduledStateDisabled
entry.nextRetryAt = time.Time{}
default:
entry.state = scheduledStateBlocked
entry.nextRetryAt = next
}
changed = true
}
if changed {
m.rebuildIndexesLocked()
}
}
// pickReadyLocked selects the next ready auth from the highest available priority bucket.
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
m.promoteExpiredLocked(time.Now())
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
if !okPriority {
return nil
}
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
}
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
if m == nil {
return 0, false
}
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
if view.pickFirst(predicate) != nil {
return priority, true
}
}
return 0, false
}
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return nil
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
var picked *scheduledAuth
if strategy == schedulerStrategyFillFirst {
picked = view.pickFirst(predicate)
} else {
picked = view.pickRoundRobin(predicate)
}
if picked == nil || picked.auth == nil {
return nil
}
return picked.auth
}
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
now := time.Now()
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
if total == 0 {
return &Error{Code: "auth_not_found", Message: "no auth available"}
}
if cooldownCount == total && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return newModelCooldownError(model, providerForError, resetIn)
}
return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
if m == nil {
return 0, 0, time.Time{}
}
total := 0
cooldownCount := 0
earliest := time.Time{}
for _, entry := range m.entries {
if predicate != nil && !predicate(entry) {
continue
}
total++
if entry == nil || entry.auth == nil {
continue
}
if entry.state != scheduledStateCooldown {
continue
}
cooldownCount++
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
earliest = entry.nextRetryAt
}
}
return total, cooldownCount, earliest
}
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
func (m *modelScheduler) rebuildIndexesLocked() {
m.readyByPriority = make(map[int]*readyBucket)
m.priorityOrder = m.priorityOrder[:0]
m.blocked = m.blocked[:0]
priorityBuckets := make(map[int][]*scheduledAuth)
for _, entry := range m.entries {
if entry == nil || entry.auth == nil {
continue
}
switch entry.state {
case scheduledStateReady:
priority := entry.meta.priority
priorityBuckets[priority] = append(priorityBuckets[priority], entry)
case scheduledStateCooldown, scheduledStateBlocked:
m.blocked = append(m.blocked, entry)
}
}
for priority, entries := range priorityBuckets {
sort.Slice(entries, func(i, j int) bool {
return entries[i].auth.ID < entries[j].auth.ID
})
m.readyByPriority[priority] = buildReadyBucket(entries)
m.priorityOrder = append(m.priorityOrder, priority)
}
sort.Slice(m.priorityOrder, func(i, j int) bool {
return m.priorityOrder[i] > m.priorityOrder[j]
})
sort.Slice(m.blocked, func(i, j int) bool {
left := m.blocked[i]
right := m.blocked[j]
if left == nil || right == nil {
return left != nil
}
if left.nextRetryAt.Equal(right.nextRetryAt) {
return left.auth.ID < right.auth.ID
}
if left.nextRetryAt.IsZero() {
return false
}
if right.nextRetryAt.IsZero() {
return true
}
return left.nextRetryAt.Before(right.nextRetryAt)
})
}
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
bucket := &readyBucket{}
bucket.all = buildReadyView(entries)
wsEntries := make([]*scheduledAuth, 0, len(entries))
for _, entry := range entries {
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
wsEntries = append(wsEntries, entry)
}
}
bucket.ws = buildReadyView(wsEntries)
return bucket
}
// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
func buildReadyView(entries []*scheduledAuth) readyView {
view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
if len(entries) == 0 {
return view
}
groups := make(map[string][]*scheduledAuth)
for _, entry := range entries {
if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
return view
}
groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
}
if len(groups) <= 1 {
return view
}
view.children = make(map[string]*childBucket, len(groups))
view.parentOrder = make([]string, 0, len(groups))
for parent := range groups {
view.parentOrder = append(view.parentOrder, parent)
}
sort.Strings(view.parentOrder)
for _, parent := range view.parentOrder {
view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
}
return view
}
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
for _, entry := range v.flat {
if predicate == nil || predicate(entry) {
return entry
}
}
return nil
}
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
if len(v.parentOrder) > 1 && len(v.children) > 0 {
return v.pickGroupedRoundRobin(predicate)
}
if len(v.flat) == 0 {
return nil
}
start := 0
if len(v.flat) > 0 {
start = v.cursor % len(v.flat)
}
for offset := 0; offset < len(v.flat); offset++ {
index := (start + offset) % len(v.flat)
entry := v.flat[index]
if predicate != nil && !predicate(entry) {
continue
}
v.cursor = index + 1
return entry
}
return nil
}
// pickGroupedRoundRobin rotates across parents first and then within the selected parent.
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
start := 0
if len(v.parentOrder) > 0 {
start = v.parentCursor % len(v.parentOrder)
}
for offset := 0; offset < len(v.parentOrder); offset++ {
parentIndex := (start + offset) % len(v.parentOrder)
parent := v.parentOrder[parentIndex]
child := v.children[parent]
if child == nil || len(child.items) == 0 {
continue
}
itemStart := child.cursor % len(child.items)
for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
itemIndex := (itemStart + itemOffset) % len(child.items)
entry := child.items[itemIndex]
if predicate != nil && !predicate(entry) {
continue
}
child.cursor = itemIndex + 1
v.parentCursor = parentIndex + 1
return entry
}
}
return nil
}

View File

@@ -0,0 +1,216 @@
package auth
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerBenchmarkExecutor struct {
id string
}
func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }
func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
b.Helper()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
providers := []string{"gemini"}
manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
if mixed {
providers = []string{"gemini", "claude"}
manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
}
reg := registry.GetGlobalRegistry()
model := "bench-model"
for index := 0; index < total; index++ {
provider := providers[0]
if mixed && index%2 == 1 {
provider = providers[1]
}
auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
if withPriority {
priority := "0"
if index%2 == 0 {
priority = "10"
}
auth.Attributes = map[string]string{"priority": priority}
}
_, errRegister := manager.Register(context.Background(), auth)
if errRegister != nil {
b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
}
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
}
manager.syncScheduler()
b.Cleanup(func() {
for index := 0; index < total; index++ {
provider := providers[0]
if mixed && index%2 == 1 {
provider = providers[1]
}
reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
}
})
return manager, providers, model
}
func BenchmarkManagerPickNext500(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 500, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNext1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextPriority500(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 500, false, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextMixed500(b *testing.B) {
manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
if errPick != nil || auth == nil || exec == nil || provider == "" {
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
}
}
}
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
if errPick != nil || auth == nil || exec == nil || provider == "" {
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
}
}
}
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil {
b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
}
manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
}
}

View File

@@ -0,0 +1,503 @@
package auth
import (
"context"
"net/http"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerTestExecutor struct{}
func (schedulerTestExecutor) Identifier() string { return "test" }
func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
type trackingSelector struct {
calls int
lastAuthID []string
}
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
s.calls++
s.lastAuthID = s.lastAuthID[:0]
for _, auth := range auths {
s.lastAuthID = append(s.lastAuthID, auth.ID)
}
if len(auths) == 0 {
return nil, nil
}
return auths[len(auths)-1], nil
}
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
scheduler := newAuthScheduler(selector)
scheduler.rebuild(auths)
return scheduler
}
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
t.Helper()
reg := registry.GetGlobalRegistry()
for _, authID := range authIDs {
reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
}
t.Cleanup(func() {
for _, authID := range authIDs {
reg.UnregisterClient(authID)
}
})
}
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
)
want := []string{"high-a", "high-b", "high-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&FillFirstSelector{},
&Auth{ID: "b", Provider: "gemini"},
&Auth{ID: "a", Provider: "gemini"},
&Auth{ID: "c", Provider: "gemini"},
)
for index := 0; index < 3; index++ {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != "a" {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
}
}
}
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
t.Parallel()
model := "gemini-2.5-pro"
registerSchedulerModels(t, "gemini", model, "cooldown-expired")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{
ID: "cooldown-expired",
Provider: "gemini",
ModelStates: map[string]*ModelState{
model: {
Status: StatusError,
Unavailable: true,
NextRetryAfter: time.Now().Add(-1 * time.Second),
},
},
},
)
got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickSingle() auth = nil")
}
if got.ID != "cooldown-expired" {
t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
}
}
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
t.Parallel()
registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
)
wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
for index := range wantIDs {
got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantIDs[index] {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
}
}
}
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "codex-http", Provider: "codex"},
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
)
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "gemini-a", Provider: "gemini"},
&Auth{ID: "gemini-b", Provider: "gemini"},
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
t.Parallel()
model := "gpt-default"
registerSchedulerModels(t, "provider-low", model, "low")
registerSchedulerModels(t, "provider-high-a", model, "high-a")
registerSchedulerModels(t, "provider-high-b", model, "high-b")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
)
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
t.Parallel()
selector := &trackingSelector{}
manager := NewManager(nil, selector, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"}
manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"}
got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickNext() auth = nil")
}
if selector.calls != 1 {
t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
}
if len(selector.lastAuthID) != 2 {
t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2)
}
if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] {
t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1])
}
}
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
if manager.scheduler == nil {
t.Fatalf("manager.scheduler = nil")
}
if manager.scheduler.strategy != schedulerStrategyRoundRobin {
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
}
manager.SetSelector(&FillFirstSelector{})
if manager.scheduler.strategy != schedulerStrategyFillFirst {
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
}
}
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-a) error = %v", errRegister)
}
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() error = %v", errPick)
}
if got == nil || got.ID != "auth-a" {
t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got)
}
if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil {
t.Fatalf("Update(auth-a) error = %v", errUpdate)
}
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
}
if got == nil || got.ID != "auth-b" {
t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got)
}
}
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNextMixed() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() auth = nil")
}
if provider != "claude" {
t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
}
if got.ID != "claude-a" {
t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a")
}
}
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
reg.UnregisterClient("auth-a")
reg.UnregisterClient("auth-b")
})
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-b) error = %v", errRegister)
}
manager.MarkResult(context.Background(), Result{
AuthID: "auth-a",
Provider: "gemini",
Model: "test-model",
Success: false,
Error: &Error{HTTPStatus: 429, Message: "quota"},
})
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
}
if got == nil || got.ID != "auth-b" {
t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got)
}
manager.MarkResult(context.Background(), Result{
AuthID: "auth-a",
Provider: "gemini",
Model: "test-model",
Success: true,
})
seen := make(map[string]struct{}, 2)
for index := 0; index < 2; index++ {
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index)
}
seen[got.ID] = struct{}{}
}
if len(seen) != 2 {
t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
}
}

View File

@@ -390,6 +390,27 @@ func (a *Auth) AccountInfo() (string, string) {
// Check metadata for email first (OAuth-style auth)
if a.Metadata != nil {
if method, ok := a.Metadata["auth_method"].(string); ok {
switch strings.ToLower(strings.TrimSpace(method)) {
case "oauth":
for _, key := range []string{"email", "username", "name"} {
if value, okValue := a.Metadata[key].(string); okValue {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return "oauth", trimmed
}
}
}
case "pat", "personal_access_token":
for _, key := range []string{"username", "email", "name", "token_preview"} {
if value, okValue := a.Metadata[key].(string); okValue {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return "personal_access_token", trimmed
}
}
}
return "personal_access_token", ""
}
}
if v, ok := a.Metadata["email"].(string); ok {
email := strings.TrimSpace(v)
if email != "" {

View File

@@ -1,16 +1,13 @@
package cliproxy
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on
@@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http.
if rt != nil {
return rt
}
// Parse the proxy URL to determine the scheme.
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil
}
var transport *http.Transport
// Handle different proxy schemes.
if proxyURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer.
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
if transport == nil {
return nil
}
p.mu.Lock()

View File

@@ -0,0 +1,22 @@
package cliproxy
import (
"net/http"
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRoundTripperForDirectBypassesProxy(t *testing.T) {
t.Parallel()
provider := newDefaultRoundTripperProvider()
rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"})
transport, ok := rt.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", rt)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}

View File

@@ -119,6 +119,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
sdkAuth.NewCodexAuthenticator(),
sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewQwenAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
)
}
@@ -293,8 +294,6 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// IMPORTANT: Update coreManager FIRST, before model registration.
// This ensures that configuration changes (proxy_url, prefix, etc.) take effect
// immediately for API calls, rather than waiting for model registration to complete.
// Model registration may involve network calls (e.g., FetchAntigravityModels) that
// could timeout if the new proxy_url is unreachable.
op := "register"
var err error
if existing, ok := s.coreManager.GetByID(auth.ID); ok {
@@ -323,6 +322,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// This operation may block on network calls, but the auth configuration
// is already effective at this point.
s.registerModelsForAuth(auth)
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
// from the now-populated global model registry. Without this, newly added auths
// have an empty supportedModelSet (because Register/Update upserts into the
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
s.coreManager.RefreshSchedulerEntry(auth.ID)
}
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
@@ -438,6 +443,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
case "github-copilot":
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
case "gitlab":
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {
@@ -447,6 +454,17 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
}
}
func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) {
if a == nil || a.ID == "" {
return
}
if len(models) == 0 {
GlobalModelRegistry().UnregisterClient(a.ID)
return
}
GlobalModelRegistry().RegisterClient(a.ID, providerKey, models)
}
// rebindExecutors refreshes provider executors so they observe the latest configuration.
func (s *Service) rebindExecutors() {
if s == nil || s.coreManager == nil {
@@ -554,6 +572,44 @@ func (s *Service) Run(ctx context.Context) error {
s.hooks.OnBeforeStart(s.cfg)
}
// Register callback for startup and periodic model catalog refresh.
// When remote model definitions change, re-register models for affected providers.
// This intentionally rebuilds per-auth model availability from the latest catalog
// snapshot instead of preserving prior registry suppression state.
registry.SetModelRefreshCallback(func(changedProviders []string) {
if s == nil || s.coreManager == nil || len(changedProviders) == 0 {
return
}
providerSet := make(map[string]bool, len(changedProviders))
for _, p := range changedProviders {
providerSet[strings.ToLower(strings.TrimSpace(p))] = true
}
auths := s.coreManager.List()
refreshed := 0
for _, item := range auths {
if item == nil || item.ID == "" {
continue
}
auth, ok := s.coreManager.GetByID(item.ID)
if !ok || auth == nil || auth.Disabled {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if !providerSet[provider] {
continue
}
if s.refreshModelRegistrationForAuth(auth) {
refreshed++
}
}
if refreshed > 0 {
log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders)
}
})
s.serverErr = make(chan error, 1)
go func() {
if errStart := s.server.Start(); errStart != nil {
@@ -836,9 +892,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
models = registry.GetAIStudioModels()
models = applyExcludedModels(models, excluded)
case "antigravity":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
models = executor.FetchAntigravityModels(ctx, a, s.cfg)
cancel()
models = registry.GetAntigravityModels()
models = applyExcludedModels(models, excluded)
case "claude":
models = registry.GetClaudeModels()
@@ -852,7 +906,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
models = applyExcludedModels(models, excluded)
case "codex":
models = registry.GetOpenAIModels()
codexPlanType := ""
if a.Attributes != nil {
codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
}
switch strings.ToLower(codexPlanType) {
case "pro":
models = registry.GetCodexProModels()
case "plus":
models = registry.GetCodexPlusModels()
case "team":
models = registry.GetCodexTeamModels()
case "free":
models = registry.GetCodexFreeModels()
default:
models = registry.GetCodexProModels()
}
if entry := s.resolveConfigCodexKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildCodexConfigModels(entry)
@@ -870,7 +939,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
models = applyExcludedModels(models, excluded)
case "kimi":
models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded)
models = applyExcludedModels(models, excluded)
case "github-copilot":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
@@ -882,6 +951,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kilo":
models = executor.FetchKiloModels(context.Background(), a, s.cfg)
models = applyExcludedModels(models, excluded)
case "gitlab":
models = executor.GitLabModelsFromAuth(a)
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {
@@ -949,7 +1021,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if providerKey == "" {
providerKey = "openai-compatibility"
}
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
} else {
// Ensure stale registrations are cleared when model list becomes empty.
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -970,16 +1042,60 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if key == "" {
key = strings.ToLower(strings.TrimSpace(a.Provider))
}
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
if provider == "antigravity" {
s.backfillAntigravityModels(a, models)
}
s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
return
}
GlobalModelRegistry().UnregisterClient(a.ID)
}
// refreshModelRegistrationForAuth re-applies the latest model registration for
// one auth and reconciles any concurrent auth changes that race with the
// refresh. Callers are expected to pre-filter provider membership.
//
// Re-registration is deliberate: registry cooldown/suspension state is treated
// as part of the previous registration snapshot and is cleared when the auth is
// rebound to the refreshed model catalog.
func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
if s == nil || s.coreManager == nil || current == nil || current.ID == "" {
return false
}
if !current.Disabled {
s.ensureExecutorsForAuth(current)
}
s.registerModelsForAuth(current)
latest, ok := s.latestAuthForModelRegistration(current.ID)
if !ok || latest.Disabled {
GlobalModelRegistry().UnregisterClient(current.ID)
s.coreManager.RefreshSchedulerEntry(current.ID)
return false
}
// Re-apply the latest auth snapshot so concurrent auth updates cannot leave
// stale model registrations behind. This may duplicate registration work when
// no auth fields changed, but keeps the refresh path simple and correct.
s.ensureExecutorsForAuth(latest)
s.registerModelsForAuth(latest)
s.coreManager.RefreshSchedulerEntry(current.ID)
return true
}
// latestAuthForModelRegistration returns the latest auth snapshot regardless of
// provider membership. Callers use this after a registration attempt to restore
// whichever state currently owns the client ID in the global registry.
func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) {
if s == nil || s.coreManager == nil || authID == "" {
return nil, false
}
auth, ok := s.coreManager.GetByID(authID)
if !ok || auth == nil || auth.ID == "" {
return nil, false
}
return auth, true
}
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
if auth == nil || s.cfg == nil {
return nil
@@ -1118,56 +1234,6 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
return cfg.OAuthExcludedModels[providerKey]
}
func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
return
}
sourceID := ""
if source != nil {
sourceID = strings.TrimSpace(source.ID)
}
reg := registry.GetGlobalRegistry()
for _, candidate := range s.coreManager.List() {
if candidate == nil || candidate.Disabled {
continue
}
candidateID := strings.TrimSpace(candidate.ID)
if candidateID == "" || candidateID == sourceID {
continue
}
if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
continue
}
if len(reg.GetModelsForClient(candidateID)) > 0 {
continue
}
authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
if authKind == "" {
if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
excluded := s.oauthExcludedModels("antigravity", authKind)
if candidate.Attributes != nil {
if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
excluded = strings.Split(val, ",")
}
}
models := applyExcludedModels(primaryModels, excluded)
models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
if len(models) == 0 {
continue
}
reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
}
}
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
if len(models) == 0 || len(excluded) == 0 {
return models

View File

@@ -1,135 +0,0 @@
package cliproxy
import (
"context"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
source := &coreauth.Auth{
ID: "ag-backfill-source",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
target := &coreauth.Auth{
ID: "ag-backfill-target",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
manager := coreauth.NewManager(nil, nil, nil)
if _, err := manager.Register(context.Background(), source); err != nil {
t.Fatalf("register source auth: %v", err)
}
if _, err := manager.Register(context.Background(), target); err != nil {
t.Fatalf("register target auth: %v", err)
}
service := &Service{
cfg: &config.Config{},
coreManager: manager,
}
reg := registry.GetGlobalRegistry()
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
t.Cleanup(func() {
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
})
primary := []*ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
reg.RegisterClient(source.ID, "antigravity", primary)
service.backfillAntigravityModels(source, primary)
got := reg.GetModelsForClient(target.ID)
if len(got) != 2 {
t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
}
ids := make(map[string]struct{}, len(got))
for _, model := range got {
if model == nil {
continue
}
ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
}
if _, ok := ids["claude-sonnet-4-5"]; !ok {
t.Fatal("expected backfilled model claude-sonnet-4-5")
}
if _, ok := ids["gemini-2.5-pro"]; !ok {
t.Fatal("expected backfilled model gemini-2.5-pro")
}
}
func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
source := &coreauth.Auth{
ID: "ag-backfill-source-excluded",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
target := &coreauth.Auth{
ID: "ag-backfill-target-excluded",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
"excluded_models": "gemini-2.5-pro",
},
}
manager := coreauth.NewManager(nil, nil, nil)
if _, err := manager.Register(context.Background(), source); err != nil {
t.Fatalf("register source auth: %v", err)
}
if _, err := manager.Register(context.Background(), target); err != nil {
t.Fatalf("register target auth: %v", err)
}
service := &Service{
cfg: &config.Config{},
coreManager: manager,
}
reg := registry.GetGlobalRegistry()
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
t.Cleanup(func() {
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
})
primary := []*ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
reg.RegisterClient(source.ID, "antigravity", primary)
service.backfillAntigravityModels(source, primary)
got := reg.GetModelsForClient(target.ID)
if len(got) != 1 {
t.Fatalf("expected 1 model after exclusion, got %d", len(got))
}
if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
}
}

View File

@@ -0,0 +1,48 @@
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestRegisterModelsForAuth_GitLabUsesDiscoveredModels(t *testing.T) {
service := &Service{cfg: &config.Config{}}
auth := &coreauth.Auth{
ID: "gitlab-auth.json",
Provider: "gitlab",
Status: coreauth.StatusActive,
Metadata: map[string]any{
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
},
}
reg := registry.GetGlobalRegistry()
reg.UnregisterClient(auth.ID)
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
service.registerModelsForAuth(auth)
models := reg.GetModelsForClient(auth.ID)
if len(models) < 2 {
t.Fatalf("expected stable alias and discovered model, got %d entries", len(models))
}
seenAlias := false
seenDiscovered := false
for _, model := range models {
switch model.ID {
case "gitlab-duo":
seenAlias = true
case "claude-sonnet-4-5":
seenDiscovered = true
}
}
if !seenAlias || !seenDiscovered {
t.Fatalf("expected gitlab-duo and discovered model, got %+v", models)
}
}

139
sdk/proxyutil/proxy.go Normal file
View File

@@ -0,0 +1,139 @@
package proxyutil
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
// Mode describes how a proxy setting should be interpreted.
type Mode int
const (
// ModeInherit means no explicit proxy behavior was configured.
ModeInherit Mode = iota
// ModeDirect means outbound requests must bypass proxies explicitly.
ModeDirect
// ModeProxy means a concrete proxy URL was configured.
ModeProxy
// ModeInvalid means the proxy setting is present but malformed or unsupported.
ModeInvalid
)
// Setting is the normalized interpretation of a proxy configuration value.
type Setting struct {
Raw string
Mode Mode
URL *url.URL
}
// Parse normalizes a proxy configuration value into inherit, direct, or proxy modes.
func Parse(raw string) (Setting, error) {
trimmed := strings.TrimSpace(raw)
setting := Setting{Raw: trimmed}
if trimmed == "" {
setting.Mode = ModeInherit
return setting, nil
}
if strings.EqualFold(trimmed, "direct") || strings.EqualFold(trimmed, "none") {
setting.Mode = ModeDirect
return setting, nil
}
parsedURL, errParse := url.Parse(trimmed)
if errParse != nil {
setting.Mode = ModeInvalid
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
}
if parsedURL.Scheme == "" || parsedURL.Host == "" {
setting.Mode = ModeInvalid
return setting, fmt.Errorf("proxy URL missing scheme/host")
}
switch parsedURL.Scheme {
case "socks5", "http", "https":
setting.Mode = ModeProxy
setting.URL = parsedURL
return setting, nil
default:
setting.Mode = ModeInvalid
return setting, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
}
}
// NewDirectTransport returns a transport that bypasses environment proxies.
func NewDirectTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
clone := transport.Clone()
clone.Proxy = nil
return clone
}
return &http.Transport{Proxy: nil}
}
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
setting, errParse := Parse(raw)
if errParse != nil {
return nil, setting.Mode, errParse
}
switch setting.Mode {
case ModeInherit:
return nil, setting.Mode, nil
case ModeDirect:
return NewDirectTransport(), setting.Mode, nil
case ModeProxy:
if setting.URL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if setting.URL.User != nil {
username := setting.URL.User.Username()
password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
return &http.Transport{
Proxy: nil,
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}, setting.Mode, nil
}
return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
default:
return nil, setting.Mode, nil
}
}
// BuildDialer constructs a proxy dialer for settings that operate at the connection layer.
func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
setting, errParse := Parse(raw)
if errParse != nil {
return nil, setting.Mode, errParse
}
switch setting.Mode {
case ModeInherit:
return nil, setting.Mode, nil
case ModeDirect:
return proxy.Direct, setting.Mode, nil
case ModeProxy:
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
if errDialer != nil {
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
}
return dialer, setting.Mode, nil
default:
return nil, setting.Mode, nil
}
}

Some files were not shown because too many files have changed in this diff Show More