Compare commits

..

50 Commits

Author SHA1 Message Date
Luis Pater
1fa094dac6 Merge pull request #461 from MrHuangJser/main
feat(cursor): Full Cursor provider with H2 streaming, MCP tools, multi-turn & multi-account
2026-03-28 05:01:27 +08:00
Luis Pater
f55754621f Merge pull request #464 from router-for-me/plus
v6.9.4
2026-03-28 04:51:27 +08:00
Luis Pater
ac26e7db43 Merge branch 'main' into plus 2026-03-28 04:51:18 +08:00
Luis Pater
10b824fcac fix(security): validate auth file names to prevent unsafe input 2026-03-28 04:48:23 +08:00
Luis Pater
7dccc7ba2f docs(readme): remove redundant whitespace in BmoPlus sponsorship section of Chinese README 2026-03-27 20:52:14 +08:00
Luis Pater
70c90687fd docs(readme): fix formatting in BmoPlus sponsorship section of Chinese README 2026-03-27 20:49:43 +08:00
Luis Pater
8144ffd5c8 Merge pull request #2370 from B3o/add-bmoplus-sponsor
docs: add BmoPlus sponsorship banners to READMEs
2026-03-27 20:48:22 +08:00
B3o
6b45d311ec add BmoPlus sponsorship banners to READMEs 2026-03-27 18:01:35 +08:00
MrHuangJser
7386a70724 feat(cursor): auto-identify accounts from JWT sub for multi-account support
Previously Cursor required a manual ?label=xxx parameter to distinguish
accounts (unlike Codex which auto-generates filenames from JWT claims).

Cursor JWTs contain a "sub" claim (e.g. "auth0|user_XXXX") that uniquely
identifies each account. Now we:

- Add ParseJWTSub() + SubToShortHash() to extract and hash the sub claim
- Refactor GetTokenExpiry() to share the new decodeJWTPayload() helper
- Update CredentialFileName(label, subHash) to auto-generate filenames
  from the sub hash when no explicit label is provided
  (e.g. "cursor.8f202e67.json" instead of always "cursor.json")
- Add DisplayLabel() for human-readable account identification
- Store "sub" in metadata for observability
- Update both management API handler and SDK authenticator

Same account always produces the same filename (deterministic), different
accounts get different files. Explicit ?label= still takes priority.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 17:40:02 +08:00
白金
1821bf7051 docs: add BmoPlus sponsorship banners to READMEs 2026-03-27 17:39:29 +08:00
Luis Pater
d42b5d4e78 docs(readme): update QQ group information in Chinese README 2026-03-27 11:46:21 +08:00
MrHuangJser
1b7447b682 feat(cursor): implement StatusError for conductor cooldown integration
Cursor executor errors were plain fmt.Errorf — the conductor couldn't
extract HTTP status codes, so exhausted accounts never entered cooldown.

Changes:
- Add ConnectError struct to proto/connect.go: ParseConnectEndStream now
  returns *ConnectError with Code/Message fields for precise matching
- Add cursorStatusErr implementing StatusError + RetryAfter interfaces
- Add classifyCursorError() with two-layer classification:
  Layer 1: exact match on ConnectError.Code (gRPC standard codes)
    resource_exhausted → 429, unauthenticated → 401,
    permission_denied → 403, unavailable → 503, internal → 500
  Layer 2: fuzzy string match for H2 errors (RST_STREAM → 502)
- Log all ConnectError code/message pairs for observing real server
  error codes (we have no samples yet)
- Wrap Execute and ExecuteStream error returns with classifyCursorError

Now the conductor properly marks Cursor auths as cooldown on quota errors,
enabling exponential backoff and round-robin failover.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 11:42:22 +08:00
MrHuangJser
40dee4453a feat(cursor): auto-migrate sessions to healthy account on quota exhaustion
When a Cursor account's quota is exhausted, sessions bound to it can now
seamlessly continue on a different account:

Layer 1 — Checkpoint decoupling:
  Key checkpoints by conversationId (not authID:conversationId). Store
  authID inside savedCheckpoint. On lookup, if auth changed, discard the
  stale checkpoint and flatten conversation history into userText.

Layer 2 — Cross-account session cleanup:
  When a request arrives for a conversation whose session belongs to a
  different (now-exhausted) auth, close the old H2 stream and remove
  the stale session to free resources.

Layer 3 — H2Stream.Err() exposure:
  New Err() method on H2Stream so callers can inspect RST_STREAM,
  GOAWAY, or other stream-level errors after closure.

Layer 4 — processH2SessionFrames error propagation:
  Returns error instead of bare return. Connect EndStream errors (quota,
  rate limit) are now propagated instead of being logged and swallowed.

Layer 5 — Pre-response transparent retry:
  If the stream fails before any data is sent to the client, return an
  error to the conductor so it retries with a different auth — fully
  transparent to the client.

Layer 6 — Post-response error logging:
  If the stream fails after data was already sent, log a warning. The
  conductor's existing cooldown mechanism ensures the next request routes
  to a healthy account.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 10:50:32 +08:00
MrHuangJser
8902e1cccb style(cursor): replace fmt.Print* with log package for consistent logging
Address Gemini Code Assist review feedback: use logrus log package
instead of fmt.Printf/Println in Cursor auth handlers and CLI for
unified log formatting and level control.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:03:32 +08:00
黄姜恒
de5fe71478 feat(cursor): multi-account routing with round-robin and session isolation
- Add cursor/filename.go for multi-account credential file naming
- Include auth.ID in session and checkpoint keys for per-account isolation
- Record authID in cursorSession, validate on resume to prevent cross-account access
- Management API /cursor-auth-url supports ?label= for creating named accounts
- Leverages existing conductor round-robin + failover framework

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:27:49 +08:00
黄姜恒
dcfbec2990 feat(cursor): add management API for Cursor OAuth authentication
- Add RequestCursorToken handler with PKCE + polling flow
- Register /v0/management/cursor-auth-url route
- Returns login URL + state for browser auth, polls in background
- Saves cursor.json with access/refresh tokens on success

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:10:07 +08:00
黄姜恒
c95620f90e feat(cursor): conversation checkpoint + session_id for multi-turn context
- Capture conversation_checkpoint_update from Cursor server (was ignored)
- Store checkpoint per conversationId, replay as conversation_state on next request
- Use protowire to embed raw checkpoint bytes directly (no deserialization)
- Extract session_id from Claude Code metadata for stable conversationId across resume
- Flatten conversation history into userText as fallback when no checkpoint available
- Use conversationId as session key for reliable tool call resume
- Add checkpoint TTL cleanup (30min)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:51:47 +08:00
黄姜恒
9613f0b3f9 feat(cursor): deterministic conversation_id from Claude Code session cch
Extract the cch hash from Claude Code's billing header in the system
prompt (x-anthropic-billing-header: ...cch=XXXXX;) and use it to derive
a deterministic conversation_id instead of generating a random UUID.

Same Claude Code session → same cch → same conversation_id → Cursor
server can reuse conversation state across multiple turns, preserving
tool call results and other context without re-encoding history.

Also cleans up temporary debug logging from previous iterations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:29:49 +08:00
黄姜恒
274f29e26b fix(cursor): improve session key uniqueness for multi-session safety
Include system prompt prefix (first 200 chars) in session key derivation.
Claude Code sessions have unique system prompts containing cwd, session_id,
file paths, etc., making collisions between concurrent sessions from the
same user virtually impossible.

Session key now = SHA256(apiKey + model + systemPrompt[:200] + firstUserMsg)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:24:37 +08:00
黄姜恒
c8e79c3787 fix(cursor): prevent session key collision across users
Include client API key in session key derivation to prevent different
users sharing the same proxy from accidentally resuming each other's
H2 streams when they send identical first messages with the same model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:19:11 +08:00
黄姜恒
8afef43887 fix(cursor): preserve tool call context in multi-turn conversations
When an assistant message appears after tool results without a pending
user message, append it to the last turn's assistant text instead of
dropping it. Also add bakeToolResultsIntoTurns() to merge tool results
into turn context when no active H2 session exists for resume, ensuring
the model sees the full tool interaction history in follow-up requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:15:24 +08:00
黄姜恒
c1083cbfc6 fix(cursor): MCP tool call resume, H2 flow control, and token usage
- Rewrite tool call mechanism from interrupt-resume to inline-wait mode:
  processH2SessionFrames no longer exits on mcpArgs; instead blocks on
  toolResultCh while continuing to handle KV/heartbeat messages, then
  sends MCP result and continues processing text in the same goroutine.
  Fixes the issue where server stopped generating text after resume.

- Add switchable output channel (outMu/currentOut) so first HTTP response
  closes after tool_calls+[DONE], and resumed text goes to a new channel
  returned by resumeWithToolResults. Reset streamParam on switch so
  Translator produces fresh message_start/content_block_start events.

- Implement send-side H2 flow control: track server's initial window size
  and WINDOW_UPDATE increments; Write() blocks when window exhausted.
  Fixes RST_STREAM FLOW_CONTROL_ERROR on large requests (178KB+).

- Decode new InteractionUpdate fields: TurnEndedUpdate (field 14) as
  stream termination signal, HeartbeatUpdate (field 13) silently ignored,
  TokenDeltaUpdate (field 8) for token usage tracking.

- Include token usage in final stop chunk (prompt_tokens estimated from
  payload size, completion_tokens from accumulated TokenDeltaUpdate deltas)
  so Claude CLI status bar shows non-zero token counts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:03:14 +08:00
Luis Pater
1e6bc81cfd refactor(config): replace auto-update-panel with disable-auto-update-panel for clarity 2026-03-25 10:31:44 +08:00
Luis Pater
1a149475e0 Merge pull request #2293 from Xvvln/fix/management-asset-security
fix(security): harden management panel asset updater
2026-03-25 10:22:49 +08:00
Luis Pater
e5166841db Merge pull request #2310 from shellus/fix/claude-openai-system-top-level
fix: preserve OpenAI system messages as Claude top-level system
2026-03-25 10:21:18 +08:00
黄姜恒
19c52bcb60 feat: stash code 2026-03-25 10:14:14 +08:00
Luis Pater
bb9b2d1758 Merge pull request #2320 from cikichen/build/freebsd-support
build: add freebsd support for releases
2026-03-25 10:12:35 +08:00
Luis Pater
7fa527193c Merge pull request #453 from HeCHieh/fix/github-copilot-gpt54-responses
Fix GitHub Copilot gpt-5.4 endpoint routing
2026-03-25 09:45:23 +08:00
Luis Pater
ed0eb51b4d Merge pull request #450 from lwiles692/feature/add-codebuddy-support
feat(auth): add CodeBuddy-CN browser OAuth authentication support
2026-03-25 09:43:52 +08:00
Luis Pater
0e4f669c8b Merge branch 'router-for-me:main' into main 2026-03-25 09:38:34 +08:00
Luis Pater
76c064c729 Merge pull request #2335 from router-for-me/auth
Support batch upload and delete for auth files
2026-03-25 09:34:44 +08:00
Luis Pater
d2f652f436 Merge pull request #2333 from router-for-me/codex
feat(codex): pass through codex client identity headers
2026-03-25 09:34:09 +08:00
Luis Pater
6a452a54d5 Merge pull request #2316 from router-for-me/openai
Add per-model thinking support for OpenAI compatibility
2026-03-25 09:31:28 +08:00
hkfires
9e5693e74f feat(api): support batch auth file upload and delete 2026-03-25 09:20:17 +08:00
hkfires
528b1a2307 feat(codex): pass through codex client identity headers 2026-03-25 08:48:18 +08:00
Luis Pater
0cc978ec1d Merge pull request #2297 from router-for-me/readme
docs(readme): update japanese documentation links
2026-03-25 03:11:24 +08:00
simon
d312422ab4 build: add freebsd support to releases 2026-03-24 16:49:04 +08:00
hkfires
fee736933b feat(openai-compat): add per-model thinking support 2026-03-24 14:21:12 +08:00
GeJiaXiang
09c92aa0b5 fix: keep a fallback turn for system-only Claude inputs 2026-03-24 13:54:25 +08:00
GeJiaXiang
8c67b3ae64 test: verify remaining user message after system merge 2026-03-24 13:47:52 +08:00
GeJiaXiang
000e4ceb4e fix: map OpenAI system messages to Claude top-level system 2026-03-24 13:42:33 +08:00
hkfires
5c99846ecf docs(readme): update japanese documentation links 2026-03-24 09:47:01 +08:00
Luis Pater
d475aaba96 Fixed: #2274
fix(translator): omit null content fields in Codex OpenAI tool call responses
2026-03-24 01:00:57 +08:00
Xvvln
7333619f15 fix: reject oversized downloads instead of truncating; warn on unverified fallback
- Read maxAssetDownloadSize+1 bytes and error if exceeded, preventing
  silent truncation that could write a broken management.html to disk
- Log explicit warning when fallback URL is used without digest
  verification, so users are aware of the reduced security guarantee
2026-03-24 00:27:44 +08:00
Xvvln
2db8df8e38 fix(security): harden management panel asset updater
- Abort update when SHA256 digest mismatch is detected instead of
  logging a warning and proceeding (prevents MITM asset replacement)
- Cap asset download size to 10 MB via io.LimitReader (defense-in-depth
  against OOM from oversized responses)
- Add `auto-update-panel` config option (default: false) to make the
  periodic background updater opt-in; the panel is still downloaded
  on first access when missing, but no longer silently auto-updated
  every 3 hours unless explicitly enabled
2026-03-24 00:10:04 +08:00
hechieh
e6690cb447 Refine GitHub Copilot endpoint selection
Amp-Thread-ID: https://ampcode.com/threads/T-019d14cd-bc90-70ce-b1ae-87bc97332650
Co-authored-by: Amp <amp@ampcode.com>
2026-03-22 19:43:35 +08:00
hechieh
35907416b8 Fix GitHub Copilot gpt-5.4 endpoint routing
Amp-Thread-ID: https://ampcode.com/threads/T-019d14cd-bc90-70ce-b1ae-87bc97332650
Co-authored-by: Amp <amp@ampcode.com>
2026-03-22 19:05:44 +08:00
lwiles692
7275e99b41 Update internal/auth/codebuddy/codebuddy_auth.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-19 09:46:59 +08:00
lwiles692
c28b65f849 Update internal/auth/codebuddy/codebuddy_auth.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-19 09:46:40 +08:00
Wei Lee
4022e69651 feat(auth): add CodeBuddy-CN browser OAuth authentication support 2026-03-18 17:50:12 +08:00
55 changed files with 8178 additions and 157 deletions

1
.gitignore vendored
View File

@@ -1,6 +1,7 @@
# Binaries
cli-proxy-api
cliproxy
/server
*.exe

View File

@@ -8,6 +8,7 @@ builds:
- linux
- windows
- darwin
- freebsd
goarch:
- amd64
- arm64

View File

@@ -30,6 +30,10 @@ GLM CODING PLANを10%割引で取得https://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝しますAICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引がありますCLIProxyAPIユーザー向けの特別特典<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたしますBmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割90% OFF</b> という驚異的な価格でご利用いただけます!</td>
</tr>
</tbody>
</table>
@@ -58,11 +62,11 @@ GLM CODING PLANを10%割引で取得https://z.ai/subscribe?ic=8JVLJQFSKB
## はじめに
CLIProxyAPIガイド[https://help.router-for.me/ja/](https://help.router-for.me/ja/)
CLIProxyAPIガイド[https://help.router-for.me/](https://help.router-for.me/)
## 管理API
[MANAGEMENT_API.md](https://help.router-for.me/ja/management/api)を参照
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
## Amp CLIサポート
@@ -74,7 +78,7 @@ CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5``claude-sonnet-4`
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/ja/agent-client/amp-cli.html)**
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
## SDKドキュメント

BIN
assets/bmoplus.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

20
cmd/mcpdebug/main.go Normal file
View File

@@ -0,0 +1,20 @@
package main
import (
"encoding/hex"
"fmt"
"os"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
// Encode MCP result with empty execId
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
// Write to file for analysis
os.WriteFile("mcp_result.bin", resultBytes)
fmt.Println("Wrote mcp_result.bin")
}

32
cmd/protocheck/main.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"fmt"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
ecm := cursorproto.NewMsg("ExecClientMessage")
// Try different field names
names := []string{
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
"shell_result", "shellResult",
}
for _, name := range names {
fd := ecm.Descriptor().Fields().ByName(name)
if fd != nil {
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
} else {
fmt.Printf("Field %q NOT FOUND\n", name)
}
}
// List all fields
fmt.Println("\nAll fields in ExecClientMessage:")
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
f := ecm.Descriptor().Fields().Get(i)
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
}
}

View File

@@ -85,6 +85,7 @@ func main() {
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var cursorLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
var kiroAWSLogin bool
@@ -95,6 +96,7 @@ func main() {
var kiroIDCRegion string
var kiroIDCFlow string
var githubCopilotLogin bool
var codeBuddyLogin bool
var projectID string
var vertexImport string
var configPath string
@@ -122,6 +124,7 @@ func main() {
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
@@ -132,6 +135,7 @@ func main() {
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
@@ -516,6 +520,9 @@ func main() {
} else if githubCopilotLogin {
// Handle GitHub Copilot login
cmd.DoGitHubCopilotLogin(cfg, options)
} else if codeBuddyLogin {
// Handle CodeBuddy login
cmd.DoCodeBuddyLogin(cfg, options)
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
@@ -539,6 +546,8 @@ func main() {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if cursorLogin {
cmd.DoCursorLogin(cfg, options)
} else if kiroLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito

View File

@@ -25,6 +25,10 @@ remote-management:
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
# When enabled, the panel is only downloaded on first access if missing, and never auto-updated afterward.
# disable-auto-update-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
@@ -238,7 +242,9 @@ nonstream-keepalive-interval: 0
# - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# alias: "kimi-k2" # The alias used in the API.
# thinking: # optional: omit to default to levels ["low","medium","high"]
# levels: ["low", "medium", "high"]
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,

2
go.mod
View File

@@ -91,8 +91,8 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect

View File

@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"net/http"
"net/url"
@@ -28,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
@@ -66,8 +68,10 @@ type callbackForwarder struct {
}
var (
callbackForwardersMu sync.Mutex
callbackForwarders = make(map[int]*callbackForwarder)
callbackForwardersMu sync.Mutex
callbackForwarders = make(map[int]*callbackForwarder)
errAuthFileMustBeJSON = errors.New("auth file must be .json")
errAuthFileNotFound = errors.New("auth file not found")
)
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
@@ -547,10 +551,23 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
}
func isUnsafeAuthFileName(name string) bool {
if strings.TrimSpace(name) == "" {
return true
}
if strings.ContainsAny(name, "/\\") {
return true
}
if filepath.VolumeName(name) != "" {
return true
}
return false
}
// Download single auth file by name
func (h *Handler) DownloadAuthFile(c *gin.Context) {
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
name := strings.TrimSpace(c.Query("name"))
if isUnsafeAuthFileName(name) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
@@ -579,36 +596,61 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
return
}
ctx := c.Request.Context()
if file, err := c.FormFile("file"); err == nil && file != nil {
name := filepath.Base(file.Filename)
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "file must be .json"})
return
}
dst := filepath.Join(h.cfg.AuthDir, name)
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
return
}
data, errRead := os.ReadFile(dst)
if errRead != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
return
}
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
c.JSON(500, gin.H{"error": errReg.Error()})
return
}
c.JSON(200, gin.H{"status": "ok"})
fileHeaders, errMultipart := h.multipartAuthFileHeaders(c)
if errMultipart != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid multipart form: %v", errMultipart)})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
if len(fileHeaders) == 1 {
if _, errUpload := h.storeUploadedAuthFile(ctx, fileHeaders[0]); errUpload != nil {
if errors.Is(errUpload, errAuthFileMustBeJSON) {
c.JSON(http.StatusBadRequest, gin.H{"error": "file must be .json"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": errUpload.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
return
}
if len(fileHeaders) > 1 {
uploaded := make([]string, 0, len(fileHeaders))
failed := make([]gin.H, 0)
for _, file := range fileHeaders {
name, errUpload := h.storeUploadedAuthFile(ctx, file)
if errUpload != nil {
failureName := ""
if file != nil {
failureName = filepath.Base(file.Filename)
}
msg := errUpload.Error()
if errors.Is(errUpload, errAuthFileMustBeJSON) {
msg = "file must be .json"
}
failed = append(failed, gin.H{"name": failureName, "error": msg})
continue
}
uploaded = append(uploaded, name)
}
if len(failed) > 0 {
c.JSON(http.StatusMultiStatus, gin.H{
"status": "partial",
"uploaded": len(uploaded),
"files": uploaded,
"failed": failed,
})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok", "uploaded": len(uploaded), "files": uploaded})
return
}
if c.ContentType() == "multipart/form-data" {
c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
return
}
name := strings.TrimSpace(c.Query("name"))
if isUnsafeAuthFileName(name) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
@@ -621,17 +663,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
return
}
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
if err = h.writeAuthFile(ctx, filepath.Base(name), data); err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
@@ -678,11 +710,182 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
names, errNames := requestedAuthFileNamesForDelete(c)
if errNames != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errNames.Error()})
return
}
if len(names) == 0 {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
if len(names) == 1 {
if _, status, errDelete := h.deleteAuthFileByName(ctx, names[0]); errDelete != nil {
c.JSON(status, gin.H{"error": errDelete.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
return
}
deletedFiles := make([]string, 0, len(names))
failed := make([]gin.H, 0)
for _, name := range names {
deletedName, _, errDelete := h.deleteAuthFileByName(ctx, name)
if errDelete != nil {
failed = append(failed, gin.H{"name": name, "error": errDelete.Error()})
continue
}
deletedFiles = append(deletedFiles, deletedName)
}
if len(failed) > 0 {
c.JSON(http.StatusMultiStatus, gin.H{
"status": "partial",
"deleted": len(deletedFiles),
"files": deletedFiles,
"failed": failed,
})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok", "deleted": len(deletedFiles), "files": deletedFiles})
}
func (h *Handler) multipartAuthFileHeaders(c *gin.Context) ([]*multipart.FileHeader, error) {
if h == nil || c == nil || c.ContentType() != "multipart/form-data" {
return nil, nil
}
form, err := c.MultipartForm()
if err != nil {
return nil, err
}
if form == nil || len(form.File) == 0 {
return nil, nil
}
keys := make([]string, 0, len(form.File))
for key := range form.File {
keys = append(keys, key)
}
sort.Strings(keys)
headers := make([]*multipart.FileHeader, 0)
for _, key := range keys {
headers = append(headers, form.File[key]...)
}
return headers, nil
}
func (h *Handler) storeUploadedAuthFile(ctx context.Context, file *multipart.FileHeader) (string, error) {
if file == nil {
return "", fmt.Errorf("no file uploaded")
}
name := filepath.Base(strings.TrimSpace(file.Filename))
if !strings.HasSuffix(strings.ToLower(name), ".json") {
return "", errAuthFileMustBeJSON
}
src, err := file.Open()
if err != nil {
return "", fmt.Errorf("failed to open uploaded file: %w", err)
}
defer src.Close()
data, err := io.ReadAll(src)
if err != nil {
return "", fmt.Errorf("failed to read uploaded file: %w", err)
}
if err := h.writeAuthFile(ctx, name, data); err != nil {
return "", err
}
return name, nil
}
func (h *Handler) writeAuthFile(ctx context.Context, name string, data []byte) error {
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
auth, err := h.buildAuthFromFileData(dst, data)
if err != nil {
return err
}
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
return fmt.Errorf("failed to write file: %w", errWrite)
}
if err := h.upsertAuthRecord(ctx, auth); err != nil {
return err
}
return nil
}
func requestedAuthFileNamesForDelete(c *gin.Context) ([]string, error) {
if c == nil {
return nil, nil
}
names := uniqueAuthFileNames(c.QueryArray("name"))
if len(names) > 0 {
return names, nil
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read body")
}
body = bytes.TrimSpace(body)
if len(body) == 0 {
return nil, nil
}
var objectBody struct {
Name string `json:"name"`
Names []string `json:"names"`
}
if body[0] == '[' {
var arrayBody []string
if err := json.Unmarshal(body, &arrayBody); err != nil {
return nil, fmt.Errorf("invalid request body")
}
return uniqueAuthFileNames(arrayBody), nil
}
if err := json.Unmarshal(body, &objectBody); err != nil {
return nil, fmt.Errorf("invalid request body")
}
out := make([]string, 0, len(objectBody.Names)+1)
if strings.TrimSpace(objectBody.Name) != "" {
out = append(out, objectBody.Name)
}
out = append(out, objectBody.Names...)
return uniqueAuthFileNames(out), nil
}
func uniqueAuthFileNames(names []string) []string {
if len(names) == 0 {
return nil
}
seen := make(map[string]struct{}, len(names))
out := make([]string, 0, len(names))
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
out = append(out, name)
}
return out
}
func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
name = strings.TrimSpace(name)
if isUnsafeAuthFileName(name) {
return "", http.StatusBadRequest, fmt.Errorf("invalid name")
}
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
targetID := ""
@@ -699,22 +902,19 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
}
if errRemove := os.Remove(targetPath); errRemove != nil {
if os.IsNotExist(errRemove) {
c.JSON(404, gin.H{"error": "file not found"})
} else {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
return filepath.Base(name), http.StatusNotFound, errAuthFileNotFound
}
return
return filepath.Base(name), http.StatusInternalServerError, fmt.Errorf("failed to remove file: %w", errRemove)
}
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
return
return filepath.Base(name), http.StatusInternalServerError, errDeleteRecord
}
if targetID != "" {
h.disableAuth(ctx, targetID)
} else {
h.disableAuth(ctx, targetPath)
}
c.JSON(200, gin.H{"status": "ok"})
return filepath.Base(name), http.StatusOK, nil
}
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
@@ -783,19 +983,27 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
if h.authManager == nil {
return nil
}
auth, err := h.buildAuthFromFileData(path, data)
if err != nil {
return err
}
return h.upsertAuthRecord(ctx, auth)
}
func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Auth, error) {
if path == "" {
return fmt.Errorf("auth path is empty")
return nil, fmt.Errorf("auth path is empty")
}
if data == nil {
var err error
data, err = os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read auth file: %w", err)
return nil, fmt.Errorf("failed to read auth file: %w", err)
}
}
metadata := make(map[string]any)
if err := json.Unmarshal(data, &metadata); err != nil {
return fmt.Errorf("invalid auth file: %w", err)
return nil, fmt.Errorf("invalid auth file: %w", err)
}
provider, _ := metadata["type"].(string)
if provider == "" {
@@ -829,13 +1037,25 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
if hasLastRefresh {
auth.LastRefreshedAt = lastRefresh
}
if existing, ok := h.authManager.GetByID(authID); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
if h != nil && h.authManager != nil {
if existing, ok := h.authManager.GetByID(authID); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
}
auth.NextRefreshAfter = existing.NextRefreshAfter
auth.Runtime = existing.Runtime
}
auth.NextRefreshAfter = existing.NextRefreshAfter
auth.Runtime = existing.Runtime
}
return auth, nil
}
func (h *Handler) upsertAuthRecord(ctx context.Context, auth *coreauth.Auth) error {
if h == nil || h.authManager == nil || auth == nil {
return nil
}
if existing, ok := h.authManager.GetByID(auth.ID); ok {
auth.CreatedAt = existing.CreatedAt
_, err := h.authManager.Update(ctx, auth)
return err
}
@@ -3488,3 +3708,84 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
"verification_uri": resp.VerificationURL,
})
}
// RequestCursorToken initiates the Cursor PKCE authentication flow.
// Supports multiple accounts via ?label=xxx query parameter.
// The user opens the returned URL in a browser, logs in, and the server polls
// until the authentication completes.
func (h *Handler) RequestCursorToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
label := strings.TrimSpace(c.Query("label"))
log.Infof("Initializing Cursor authentication (label=%q)...", label)
authParams, err := cursorauth.GenerateAuthParams()
if err != nil {
log.Errorf("Failed to generate Cursor auth params: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate auth params"})
return
}
state := fmt.Sprintf("cur-%d", time.Now().UnixNano())
RegisterOAuthSession(state, "cursor")
go func() {
log.Info("Waiting for Cursor authentication...")
log.Infof("Open this URL in your browser: %s", authParams.LoginURL)
tokens, errPoll := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed: "+errPoll.Error())
log.Errorf("Cursor authentication failed: %v", errPoll)
return
}
// Build metadata
metadata := map[string]any{
"type": "cursor",
"access_token": tokens.AccessToken,
"refresh_token": tokens.RefreshToken,
"timestamp": time.Now().UnixMilli(),
}
// Extract expiry and account identity from JWT
expiry := cursorauth.GetTokenExpiry(tokens.AccessToken)
if !expiry.IsZero() {
metadata["expires_at"] = expiry.Format(time.RFC3339)
}
// Auto-identify account from JWT sub claim for multi-account support
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
subHash := cursorauth.SubToShortHash(sub)
if sub != "" {
metadata["sub"] = sub
}
fileName := cursorauth.CredentialFileName(label, subHash)
displayLabel := cursorauth.DisplayLabel(label, subHash)
record := &coreauth.Auth{
ID: fileName,
Provider: "cursor",
FileName: fileName,
Label: displayLabel,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save Cursor tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save tokens")
return
}
log.Infof("Cursor authentication successful! Token saved to %s", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("cursor")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authParams.LoginURL,
"state": state,
})
}

View File

@@ -0,0 +1,197 @@
package management
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
files := []struct {
name string
content string
}{
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for _, file := range files {
part, err := writer.CreateFormFile("file", file.name)
if err != nil {
t.Fatalf("failed to create multipart file: %v", err)
}
if _, err = part.Write([]byte(file.content)); err != nil {
t.Fatalf("failed to write multipart content: %v", err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("failed to close multipart writer: %v", err)
}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
ctx.Request = req
h.UploadAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
}
for _, file := range files {
fullPath := filepath.Join(authDir, file.name)
data, err := os.ReadFile(fullPath)
if err != nil {
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
}
if string(data) != file.content {
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
}
}
auths := manager.List()
if len(auths) != len(files) {
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
}
}
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
existingName := "alpha.json"
existingContent := `{"type":"codex","email":"alpha@example.com"}`
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
t.Fatalf("failed to seed existing auth file: %v", err)
}
files := []struct {
name string
content string
}{
{name: existingName, content: `{"type":"codex"`},
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for _, file := range files {
part, err := writer.CreateFormFile("file", file.name)
if err != nil {
t.Fatalf("failed to create multipart file: %v", err)
}
if _, err = part.Write([]byte(file.content)); err != nil {
t.Fatalf("failed to write multipart content: %v", err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("failed to close multipart writer: %v", err)
}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
ctx.Request = req
h.UploadAuthFile(ctx)
if rec.Code != http.StatusMultiStatus {
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
}
data, err := os.ReadFile(filepath.Join(authDir, existingName))
if err != nil {
t.Fatalf("expected existing auth file to remain readable: %v", err)
}
if string(data) != existingContent {
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
}
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
if err != nil {
t.Fatalf("expected valid auth file to be created: %v", err)
}
if string(betaData) != files[1].content {
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
}
}
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
files := []string{"alpha.json", "beta.json"}
for _, name := range files {
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
t.Fatalf("failed to write auth file %s: %v", name, err)
}
}
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
h.tokenStore = &memoryAuthStore{}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(
http.MethodDelete,
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
nil,
)
ctx.Request = req
h.DeleteAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
}
for _, name := range files {
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
}
}
}

View File

@@ -0,0 +1,62 @@
package management
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
fileName := "download-user.json"
expected := []byte(`{"type":"codex"}`)
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
if got := rec.Body.Bytes(); string(got) != string(expected) {
t.Fatalf("unexpected download content: %q", string(got))
}
}
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
for _, name := range []string{
"../external/secret.json",
`..\\external\\secret.json`,
"nested/secret.json",
`nested\\secret.json`,
} {
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
}
}
}

View File

@@ -0,0 +1,51 @@
//go:build windows
package management
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
tempDir := t.TempDir()
authDir := filepath.Join(tempDir, "auth")
externalDir := filepath.Join(tempDir, "external")
if err := os.MkdirAll(authDir, 0o700); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
if err := os.MkdirAll(externalDir, 0o700); err != nil {
t.Fatalf("failed to create external dir: %v", err)
}
secretName := "secret.json"
secretPath := filepath.Join(externalDir, secretName)
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
t.Fatalf("failed to write external file: %v", err)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(
http.MethodGet,
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
nil,
)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
}
}

View File

@@ -682,6 +682,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)

View File

@@ -0,0 +1,335 @@
package codebuddy
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const (
BaseURL = "https://copilot.tencent.com"
DefaultDomain = "www.codebuddy.cn"
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
codeBuddyStatePath = "/v2/plugin/auth/state"
codeBuddyTokenPath = "/v2/plugin/auth/token"
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
pollInterval = 5 * time.Second
maxPollDuration = 5 * time.Minute
codeLoginPending = 11217
codeSuccess = 0
)
type CodeBuddyAuth struct {
httpClient *http.Client
cfg *config.Config
baseURL string
}
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
httpClient := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
}
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
}
// AuthState holds the state and auth URL returned by the auth state API.
type AuthState struct {
State string
AuthURL string
}
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
body := []byte("{}")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
}
requestID := uuid.NewString()
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-Domain", "copilot.tencent.com")
req.Header.Set("X-No-Authorization", "true")
req.Header.Set("X-No-User-Id", "true")
req.Header.Set("X-No-Enterprise-Id", "true")
req.Header.Set("X-No-Department-Info", "true")
req.Header.Set("X-Product", "SaaS")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("X-Request-ID", requestID)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy auth state: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data *struct {
State string `json:"state"`
AuthURL string `json:"authUrl"`
} `json:"data"`
}
if err = json.Unmarshal(bodyBytes, &result); err != nil {
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
}
if result.Code != codeSuccess {
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
}
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
}
return &AuthState{
State: result.Data.State,
AuthURL: result.Data.AuthURL,
}, nil
}
type pollResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
RequestID string `json:"requestId"`
Data *struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresIn int64 `json:"expiresIn"`
TokenType string `json:"tokenType"`
Domain string `json:"domain"`
} `json:"data"`
}
// doPollRequest performs a single polling request, safely reading and closing the response body
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
}
a.applyPollHeaders(req)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, 0, err
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy poll: close body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
}
return body, resp.StatusCode, nil
}
// PollForToken polls until the user completes browser authorization and returns auth data.
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
deadline := time.Now().Add(maxPollDuration)
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
}
body, statusCode, err := a.doPollRequest(ctx, pollURL)
if err != nil {
log.Debugf("codebuddy poll: request error: %v", err)
continue
}
if statusCode != http.StatusOK {
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
continue
}
var result pollResponse
if err := json.Unmarshal(body, &result); err != nil {
continue
}
switch result.Code {
case codeSuccess:
if result.Data == nil {
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
}
userID, _ := a.DecodeUserID(result.Data.AccessToken)
return &CodeBuddyTokenStorage{
AccessToken: result.Data.AccessToken,
RefreshToken: result.Data.RefreshToken,
ExpiresIn: result.Data.ExpiresIn,
TokenType: result.Data.TokenType,
Domain: result.Data.Domain,
UserID: userID,
Type: "codebuddy",
}, nil
case codeLoginPending:
// continue polling
default:
// TODO: when the CodeBuddy API error code for user denial is known,
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
}
}
return nil, ErrPollingTimeout
}
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return "", ErrJWTDecodeFailed
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
}
var claims struct {
Sub string `json:"sub"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
}
if claims.Sub == "" {
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
}
return claims.Sub, nil
}
// RefreshToken exchanges a refresh token for a new access token.
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
if domain == "" {
domain = DefaultDomain
}
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
body := []byte("{}")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
}
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-Domain", domain)
req.Header.Set("X-Refresh-Token", refreshToken)
req.Header.Set("X-Auth-Refresh-Source", "plugin")
req.Header.Set("X-Request-ID", requestID)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("X-User-Id", userID)
req.Header.Set("X-Product", "SaaS")
req.Header.Set("User-Agent", UserAgent)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy refresh: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data *struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresIn int64 `json:"expiresIn"`
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
TokenType string `json:"tokenType"`
Domain string `json:"domain"`
} `json:"data"`
}
if err = json.Unmarshal(bodyBytes, &result); err != nil {
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
}
if result.Code != codeSuccess {
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
}
if result.Data == nil {
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
}
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
if newUserID == "" {
newUserID = userID
}
tokenDomain := result.Data.Domain
if tokenDomain == "" {
tokenDomain = domain
}
return &CodeBuddyTokenStorage{
AccessToken: result.Data.AccessToken,
RefreshToken: result.Data.RefreshToken,
ExpiresIn: result.Data.ExpiresIn,
RefreshExpiresIn: result.Data.RefreshExpiresIn,
TokenType: result.Data.TokenType,
Domain: tokenDomain,
UserID: newUserID,
Type: "codebuddy",
}, nil
}
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-No-Authorization", "true")
req.Header.Set("X-No-User-Id", "true")
req.Header.Set("X-No-Enterprise-Id", "true")
req.Header.Set("X-No-Department-Info", "true")
req.Header.Set("X-Product", "SaaS")
}

View File

@@ -0,0 +1,285 @@
package codebuddy
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
func newTestAuth(serverURL string) *CodeBuddyAuth {
return &CodeBuddyAuth{
httpClient: http.DefaultClient,
baseURL: serverURL,
}
}
// fakeJWT builds a minimal JWT with the given sub claim for testing.
func fakeJWT(sub string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return header + "." + encodedPayload + ".sig"
}
// --- FetchAuthState tests ---
func TestFetchAuthState_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if got := r.URL.Path; got != codeBuddyStatePath {
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
}
if got := r.URL.Query().Get("platform"); got != "CLI" {
t.Errorf("expected platform=CLI, got %s", got)
}
if got := r.Header.Get("User-Agent"); got != UserAgent {
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"state": "test-state-abc",
"authUrl": "https://example.com/login?state=test-state-abc",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
result, err := auth.FetchAuthState(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.State != "test-state-abc" {
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
}
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
t.Errorf("unexpected authURL: %s", result.AuthURL)
}
}
func TestFetchAuthState_NonOKStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("internal error"))
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for non-200 status")
}
}
func TestFetchAuthState_APIErrorCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 10001,
"msg": "rate limited",
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for non-zero code")
}
}
func TestFetchAuthState_MissingData(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"state": "",
"authUrl": "",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for empty state/authUrl")
}
}
// --- RefreshToken tests ---
func TestRefreshToken_Success(t *testing.T) {
newAccessToken := fakeJWT("refreshed-user-456")
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if got := r.URL.Path; got != codeBuddyRefreshPath {
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
}
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
}
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
}
if got := r.Header.Get("X-User-Id"); got != "user-123" {
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
}
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": newAccessToken,
"refreshToken": "new-refresh-token",
"expiresIn": 3600,
"refreshExpiresIn": 86400,
"tokenType": "bearer",
"domain": "custom.domain.com",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if storage.AccessToken != newAccessToken {
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
}
if storage.RefreshToken != "new-refresh-token" {
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
}
if storage.UserID != "refreshed-user-456" {
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
}
if storage.ExpiresIn != 3600 {
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
}
if storage.RefreshExpiresIn != 86400 {
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
}
if storage.Domain != "custom.domain.com" {
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
}
if storage.Type != "codebuddy" {
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
}
}
func TestRefreshToken_DefaultDomain(t *testing.T) {
var receivedDomain string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedDomain = r.Header.Get("X-Domain")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": fakeJWT("user-1"),
"refreshToken": "rt",
"expiresIn": 3600,
"tokenType": "bearer",
"domain": DefaultDomain,
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if receivedDomain != DefaultDomain {
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
}
}
func TestRefreshToken_Unauthorized(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for 401 response")
}
}
func TestRefreshToken_Forbidden(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for 403 response")
}
}
func TestRefreshToken_APIErrorCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 40001,
"msg": "invalid refresh token",
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for non-zero API code")
}
}
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
// When the response domain is empty, it should fall back to the request domain.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": "not-a-valid-jwt",
"refreshToken": "new-rt",
"expiresIn": 7200,
"tokenType": "bearer",
"domain": "",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if storage.UserID != "original-uid" {
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
}
if storage.Domain != "original.domain.com" {
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
}
}

View File

@@ -0,0 +1,22 @@
package codebuddy_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
)
func TestDecodeUserID_ValidJWT(t *testing.T) {
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
auth := codebuddy.NewCodeBuddyAuth(nil)
userID, err := auth.DecodeUserID(token)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if userID != "test-user-id-123" {
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
}
}

View File

@@ -0,0 +1,25 @@
package codebuddy
import "errors"
var (
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
ErrAccessDenied = errors.New("codebuddy: access denied by user")
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
)
func GetUserFriendlyMessage(err error) string {
switch {
case errors.Is(err, ErrPollingTimeout):
return "Authentication timed out. Please try again."
case errors.Is(err, ErrAccessDenied):
return "Access denied. Please try again and approve the login request."
case errors.Is(err, ErrJWTDecodeFailed):
return "Failed to decode token. Please try logging in again."
case errors.Is(err, ErrTokenFetchFailed):
return "Failed to fetch token from server. Please try again."
default:
return "Authentication failed: " + err.Error()
}
}

View File

@@ -0,0 +1,65 @@
// Package codebuddy provides authentication and token management functionality
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
package codebuddy
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
// for managing access tokens and user account information.
type CodeBuddyTokenStorage struct {
// AccessToken is the OAuth2 access token used for authenticating API requests.
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the number of seconds until the access token expires.
ExpiresIn int64 `json:"expires_in"`
// RefreshExpiresIn is the number of seconds until the refresh token expires.
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
// TokenType is the type of token, typically "bearer".
TokenType string `json:"token_type"`
// Domain is the CodeBuddy service domain/region.
Domain string `json:"domain"`
// UserID is the user ID associated with this token.
UserID string `json:"user_id"`
// Type indicates the authentication provider type, always "codebuddy" for this storage.
Type string `json:"type"`
}
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
//
// Returns:
// - error: An error if the operation fails, nil otherwise
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
s.Type = "codebuddy"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(s); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,33 @@
package cursor
import (
"fmt"
"strings"
)
// CredentialFileName returns the filename used to persist Cursor credentials.
// Priority: explicit label > auto-generated from JWT sub hash.
// If both label and subHash are empty, falls back to "cursor.json".
func CredentialFileName(label, subHash string) string {
label = strings.TrimSpace(label)
subHash = strings.TrimSpace(subHash)
if label != "" {
return fmt.Sprintf("cursor.%s.json", label)
}
if subHash != "" {
return fmt.Sprintf("cursor.%s.json", subHash)
}
return "cursor.json"
}
// DisplayLabel returns a human-readable label for the Cursor account.
func DisplayLabel(label, subHash string) string {
label = strings.TrimSpace(label)
if label != "" {
return "Cursor " + label
}
if subHash != "" {
return "Cursor " + subHash
}
return "Cursor User"
}

View File

@@ -0,0 +1,249 @@
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
package cursor
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
)
const (
CursorLoginURL = "https://cursor.com/loginDeepControl"
CursorPollURL = "https://api2.cursor.sh/auth/poll"
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
pollMaxAttempts = 150
pollBaseDelay = 1 * time.Second
pollMaxDelay = 10 * time.Second
pollBackoffMultiply = 1.2
maxConsecutiveErrors = 10
)
// AuthParams holds the PKCE parameters for Cursor login.
type AuthParams struct {
Verifier string
Challenge string
UUID string
LoginURL string
}
// TokenPair holds the access and refresh tokens from Cursor.
type TokenPair struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
// GeneratePKCE creates a PKCE verifier and challenge pair.
func GeneratePKCE() (verifier, challenge string, err error) {
verifierBytes := make([]byte, 96)
if _, err = rand.Read(verifierBytes); err != nil {
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
}
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
// GenerateAuthParams creates the full set of auth params for Cursor login.
func GenerateAuthParams() (*AuthParams, error) {
verifier, challenge, err := GeneratePKCE()
if err != nil {
return nil, err
}
uuidBytes := make([]byte, 16)
if _, err = rand.Read(uuidBytes); err != nil {
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
}
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
CursorLoginURL, challenge, uuid)
return &AuthParams{
Verifier: verifier,
Challenge: challenge,
UUID: uuid,
LoginURL: loginURL,
}, nil
}
// PollForAuth polls the Cursor auth endpoint until the user completes login.
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
delay := pollBaseDelay
consecutiveErrors := 0
client := &http.Client{Timeout: 10 * time.Second}
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(delay):
}
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
consecutiveErrors++
if consecutiveErrors >= maxConsecutiveErrors {
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
}
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Still waiting for user to authorize
consecutiveErrors = 0
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
}
return &tokens, nil
}
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
}
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
}
// RefreshToken refreshes a Cursor access token using the refresh token.
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
strings.NewReader("{}"))
if err != nil {
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+refreshToken)
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
}
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
}
// Keep original refresh token if not returned
if tokens.RefreshToken == "" {
tokens.RefreshToken = refreshToken
}
return &tokens, nil
}
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
// the account. Returns empty string if parsing fails.
func ParseJWTSub(token string) string {
decoded := decodeJWTPayload(token)
if decoded == nil {
return ""
}
var claims struct {
Sub string `json:"sub"`
}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
return claims.Sub
}
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
// e.g. "auth0|user_2x..." → "a3f8b2c1"
func SubToShortHash(sub string) string {
if sub == "" {
return ""
}
h := sha256.Sum256([]byte(sub))
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
}
// decodeJWTPayload decodes the payload (middle) part of a JWT.
func decodeJWTPayload(token string) []byte {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
payload = strings.ReplaceAll(payload, "-", "+")
payload = strings.ReplaceAll(payload, "_", "/")
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil
}
return decoded
}
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
// Falls back to 1 hour from now if the token can't be parsed.
func GetTokenExpiry(token string) time.Time {
decoded := decodeJWTPayload(token)
if decoded == nil {
return time.Now().Add(1 * time.Hour)
}
var claims struct {
Exp float64 `json:"exp"`
}
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
return time.Now().Add(1 * time.Hour)
}
sec, frac := math.Modf(claims.Exp)
expiry := time.Unix(int64(sec), int64(frac*1e9))
// Subtract 5-minute safety margin
return expiry.Add(-5 * time.Minute)
}
func minDuration(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,84 @@
package proto
import (
"encoding/binary"
"encoding/json"
"fmt"
)
const (
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
ConnectEndStreamFlag byte = 0x02
// ConnectCompressionFlag indicates the payload is compressed (not supported).
ConnectCompressionFlag byte = 0x01
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
ConnectFrameHeaderSize = 5
)
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
frame := make([]byte, ConnectFrameHeaderSize+len(data))
frame[0] = flags
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
copy(frame[5:], data)
return frame
}
// ParseConnectFrame extracts one frame from a buffer.
// Returns (flags, payload, bytesConsumed, ok).
// ok is false when the buffer is too short for a complete frame.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
if len(buf) < ConnectFrameHeaderSize {
return 0, nil, 0, false
}
flags = buf[0]
length := binary.BigEndian.Uint32(buf[1:5])
total := ConnectFrameHeaderSize + int(length)
if len(buf) < total {
return 0, nil, 0, false
}
return flags, buf[5:total], total, true
}
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
// The Code field contains the server-defined error code (e.g. gRPC standard codes
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
type ConnectError struct {
Code string // server-defined error code
Message string // human-readable error description
}
func (e *ConnectError) Error() string {
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
}
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
// Returns nil if there is no error in the trailer.
// On error, returns a *ConnectError with the server's error code and message.
func ParseConnectEndStream(data []byte) error {
if len(data) == 0 {
return nil
}
var trailer struct {
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
if err := json.Unmarshal(data, &trailer); err != nil {
return fmt.Errorf("failed to parse Connect end stream: %w", err)
}
if trailer.Error != nil {
code := trailer.Error.Code
if code == "" {
code = "unknown"
}
msg := trailer.Error.Message
if msg == "" {
msg = "Unknown error"
}
return &ConnectError{Code: code, Message: msg}
}
return nil
}

View File

@@ -0,0 +1,564 @@
package proto
import (
"encoding/hex"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
)
// ServerMessageType identifies the kind of decoded server message.
type ServerMessageType int
const (
ServerMsgUnknown ServerMessageType = iota
ServerMsgTextDelta // Text content delta
ServerMsgThinkingDelta // Thinking/reasoning delta
ServerMsgThinkingCompleted // Thinking completed
ServerMsgKvGetBlob // Server wants a blob
ServerMsgKvSetBlob // Server wants to store a blob
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
ServerMsgExecMcpArgs // Server wants MCP tool execution
ServerMsgExecShellArgs // Rejected: shell command
ServerMsgExecReadArgs // Rejected: file read
ServerMsgExecWriteArgs // Rejected: file write
ServerMsgExecDeleteArgs // Rejected: file delete
ServerMsgExecLsArgs // Rejected: directory listing
ServerMsgExecGrepArgs // Rejected: grep search
ServerMsgExecFetchArgs // Rejected: HTTP fetch
ServerMsgExecDiagnostics // Respond with empty diagnostics
ServerMsgExecShellStream // Rejected: shell stream
ServerMsgExecBgShellSpawn // Rejected: background shell
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
ServerMsgExecOther // Other exec types (respond with empty)
ServerMsgTurnEnded // Turn has ended (no more output)
ServerMsgHeartbeat // Server heartbeat
ServerMsgTokenDelta // Token usage delta
ServerMsgCheckpoint // Conversation checkpoint update
)
// DecodedServerMessage holds parsed data from an AgentServerMessage.
type DecodedServerMessage struct {
Type ServerMessageType
// For text/thinking deltas
Text string
// For KV messages
KvId uint32
BlobId []byte // hex-encoded blob ID
BlobData []byte // for setBlobArgs
// For exec messages
ExecMsgId uint32
ExecId string
// For MCP args
McpToolName string
McpToolCallId string
McpArgs map[string][]byte // arg name -> protobuf-encoded value
// For rejection context
Path string
Command string
WorkingDirectory string
Url string
// For other exec - the raw field number for building a response
ExecFieldNumber int
// For TokenDeltaUpdate
TokenDelta int64
// For conversation checkpoint update (raw bytes, not decoded)
CheckpointData []byte
}
// DecodeAgentServerMessage parses an AgentServerMessage and returns
// a structured representation of the first meaningful message found.
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return msg, fmt.Errorf("invalid tag")
}
data = data[n:]
switch typ {
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return msg, fmt.Errorf("invalid bytes field %d", num)
}
data = data[n:]
// Debug: log top-level ASM fields
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
switch num {
case ASM_InteractionUpdate:
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
decodeInteractionUpdate(val, msg)
case ASM_ExecServerMessage:
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
decodeExecServerMessage(val, msg)
case ASM_KvServerMessage:
decodeKvServerMessage(val, msg)
case ASM_ConversationCheckpoint:
msg.Type = ServerMsgCheckpoint
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
}
case protowire.VarintType:
_, n := protowire.ConsumeVarint(data)
if n < 0 {
return msg, fmt.Errorf("invalid varint field %d", num)
}
data = data[n:]
default:
// Skip unknown wire types
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return msg, fmt.Errorf("invalid field %d", num)
}
data = data[n:]
}
}
return msg, nil
}
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case IU_TextDelta:
msg.Type = ServerMsgTextDelta
msg.Text = decodeStringField(val, TDU_Text)
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
case IU_ThinkingDelta:
msg.Type = ServerMsgThinkingDelta
msg.Text = decodeStringField(val, TKD_Text)
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
case IU_ThinkingCompleted:
msg.Type = ServerMsgThinkingCompleted
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
case 2:
// tool_call_started - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
case 3:
// tool_call_completed - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
case 8:
// token_delta - extract token count
msg.Type = ServerMsgTokenDelta
msg.TokenDelta = decodeVarintField(val, 1)
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
case 13:
// heartbeat from server
msg.Type = ServerMsgHeartbeat
case 14:
// turn_ended - critical: model finished generating
msg.Type = ServerMsgTurnEnded
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
case 16:
// step_started - ignore
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
case 17:
// step_completed - ignore
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
default:
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == KSM_Id {
msg.KvId = uint32(val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case KSM_GetBlobArgs:
msg.Type = ServerMsgKvGetBlob
msg.BlobId = decodeBytesField(val, GBA_BlobId)
case KSM_SetBlobArgs:
msg.Type = ServerMsgKvSetBlob
decodeSetBlobArgs(val, msg)
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SBA_BlobId:
msg.BlobId = val
case SBA_BlobData:
msg.BlobData = val
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == ESM_Id {
msg.ExecMsgId = uint32(val)
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
// Debug: log all fields found in ExecServerMessage
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case ESM_ExecId:
msg.ExecId = string(val)
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
case ESM_RequestContextArgs:
msg.Type = ServerMsgExecRequestCtx
case ESM_McpArgs:
msg.Type = ServerMsgExecMcpArgs
decodeMcpArgs(val, msg)
case ESM_ShellArgs:
msg.Type = ServerMsgExecShellArgs
decodeShellArgs(val, msg)
case ESM_ShellStreamArgs:
msg.Type = ServerMsgExecShellStream
decodeShellArgs(val, msg)
case ESM_ReadArgs:
msg.Type = ServerMsgExecReadArgs
msg.Path = decodeStringField(val, RA_Path)
case ESM_WriteArgs:
msg.Type = ServerMsgExecWriteArgs
msg.Path = decodeStringField(val, WA_Path)
case ESM_DeleteArgs:
msg.Type = ServerMsgExecDeleteArgs
msg.Path = decodeStringField(val, DA_Path)
case ESM_LsArgs:
msg.Type = ServerMsgExecLsArgs
msg.Path = decodeStringField(val, LA_Path)
case ESM_GrepArgs:
msg.Type = ServerMsgExecGrepArgs
case ESM_FetchArgs:
msg.Type = ServerMsgExecFetchArgs
msg.Url = decodeStringField(val, FA_Url)
case ESM_DiagnosticsArgs:
msg.Type = ServerMsgExecDiagnostics
case ESM_BackgroundShellSpawn:
msg.Type = ServerMsgExecBgShellSpawn
decodeShellArgs(val, msg) // same structure
case ESM_WriteShellStdinArgs:
msg.Type = ServerMsgExecWriteShellStdin
default:
// Unknown exec types - only set if we haven't identified the type yet
// (other fields like span_context (19) come after the exec type field)
if msg.Type == ServerMsgUnknown {
msg.Type = ServerMsgExecOther
msg.ExecFieldNumber = int(num)
}
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
msg.McpArgs = make(map[string][]byte)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case MCA_Name:
msg.McpToolName = string(val)
case MCA_Args:
// Map entries are encoded as submessages with key=1, value=2
decodeMapEntry(val, msg.McpArgs)
case MCA_ToolCallId:
msg.McpToolCallId = string(val)
case MCA_ToolName:
// ToolName takes precedence if present
if msg.McpToolName == "" || string(val) != "" {
msg.McpToolName = string(val)
}
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMapEntry(data []byte, m map[string][]byte) {
var key string
var value []byte
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
if num == 1 {
key = string(val)
} else if num == 2 {
value = append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
if key != "" {
m[key] = value
}
}
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SHA_Command:
msg.Command = string(val)
case SHA_WorkingDirectory:
msg.WorkingDirectory = string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
// --- Helper decoders ---
// decodeStringField extracts a string from the first matching field in a submessage.
func decodeStringField(data []byte, targetField protowire.Number) string {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return ""
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return ""
}
data = data[n:]
if num == targetField {
return string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return ""
}
data = data[n:]
}
}
return ""
}
// decodeBytesField extracts bytes from the first matching field in a submessage.
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return nil
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return nil
}
data = data[n:]
if num == targetField {
return append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return nil
}
data = data[n:]
}
}
return nil
}
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return 0
}
data = data[n:]
if typ == protowire.VarintType {
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return 0
}
data = data[n:]
if num == targetField {
return int64(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return 0
}
data = data[n:]
}
}
return 0
}
// BlobIdHex returns the hex string of a blob ID for use as a map key.
func BlobIdHex(blobId []byte) string {
return hex.EncodeToString(blobId)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,664 @@
// Package proto provides protobuf encoding for Cursor's gRPC API,
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
package proto
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/known/structpb"
)
// --- Public types ---
// RunRequestParams holds all data needed to build an AgentRunRequest.
type RunRequestParams struct {
ModelId string
SystemPrompt string
UserText string
MessageId string
ConversationId string
Images []ImageData
Turns []TurnData
McpTools []McpToolDef
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
}
type ImageData struct {
MimeType string
Data []byte
}
type TurnData struct {
UserText string
AssistantText string
}
type McpToolDef struct {
Name string
Description string
InputSchema json.RawMessage
}
// --- Helper: create a dynamic message and set fields ---
func newMsg(name string) *dynamicpb.Message {
return dynamicpb.NewMessage(Msg(name))
}
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
}
func setStr(msg *dynamicpb.Message, name, val string) {
if val != "" {
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
}
}
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
if len(val) > 0 {
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
}
}
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
}
func setBool(msg *dynamicpb.Message, name string, val bool) {
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
}
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
}
func marshal(msg *dynamicpb.Message) []byte {
b, err := proto.Marshal(msg)
if err != nil {
panic("cursor proto marshal: " + err.Error())
}
return b
}
// --- Encode functions mirroring cursor-fetch.ts ---
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
func EncodeHeartbeat() []byte {
hb := newMsg("ClientHeartbeat")
acm := newMsg("AgentClientMessage")
setMsg(acm, "client_heartbeat", hb)
return marshal(acm)
}
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
// Mirrors buildCursorRequest() in cursor-fetch.ts.
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
func EncodeRunRequest(p *RunRequestParams) []byte {
if p.RawCheckpoint != nil {
return encodeRunRequestWithCheckpoint(p)
}
if p.BlobStore == nil {
p.BlobStore = make(map[string][]byte)
}
// --- Conversation turns ---
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
var turnBytes [][]byte
for _, turn := range p.Turns {
// UserMessage for this turn
um := newMsg("UserMessage")
setStr(um, "text", turn.UserText)
setStr(um, "message_id", generateId())
umBytes := marshal(um)
// Steps (assistant response)
var stepBytes [][]byte
if turn.AssistantText != "" {
am := newMsg("AssistantMessage")
setStr(am, "text", turn.AssistantText)
step := newMsg("ConversationStep")
setMsg(step, "assistant_message", am)
stepBytes = append(stepBytes, marshal(step))
}
// AgentConversationTurnStructure (fields are bytes, not submessages)
agentTurn := newMsg("AgentConversationTurnStructure")
setBytes(agentTurn, "user_message", umBytes)
for _, sb := range stepBytes {
stepsField := field(agentTurn, "steps")
list := agentTurn.Mutable(stepsField).List()
list.Append(protoreflect.ValueOfBytes(sb))
}
// ConversationTurnStructure (oneof turn → agentConversationTurn)
cts := newMsg("ConversationTurnStructure")
setMsg(cts, "agent_conversation_turn", agentTurn)
turnBytes = append(turnBytes, marshal(cts))
}
// --- System prompt blob ---
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
blobId := sha256Sum(systemJSON)
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
// --- ConversationStateStructure ---
css := newMsg("ConversationStateStructure")
// rootPromptMessagesJson: repeated bytes
rootField := field(css, "root_prompt_messages_json")
rootList := css.Mutable(rootField).List()
rootList.Append(protoreflect.ValueOfBytes(blobId))
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
turnsField := field(css, "turns")
turnsList := css.Mutable(turnsField).List()
for _, tb := range turnBytes {
turnsList.Append(protoreflect.ValueOfBytes(tb))
}
turnsOldField := field(css, "turns_old")
if turnsOldField != nil {
turnsOldList := css.Mutable(turnsOldField).List()
for _, tb := range turnBytes {
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
}
}
// --- UserMessage (current) ---
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
// Images via SelectedContext
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// --- UserMessageAction ---
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
// --- ConversationAction ---
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
// --- ModelDetails ---
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// --- AgentRunRequest ---
arr := newMsg("AgentRunRequest")
setMsg(arr, "conversation_state", css)
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
// --- AgentClientMessage ---
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
// Build UserMessage
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// Build ConversationAction with UserMessageAction
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
caBytes := marshal(ca)
// Build ModelDetails
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
mdBytes := marshal(md)
// Build McpTools
var mcpToolsBytes []byte
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
mcpToolsBytes = marshal(mcpTools)
}
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
var arrBuf []byte
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
// field 2: action = ConversationAction
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
// field 3: model_details = ModelDetails
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
// field 4: mcp_tools = McpTools
if len(mcpToolsBytes) > 0 {
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
}
// field 5: conversation_id = string
if p.ConversationId != "" {
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
}
// Wrap in AgentClientMessage field 1 (run_request)
var acmBuf []byte
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
return acmBuf
}
// ResumeRequestParams holds data for a ResumeAction request.
type ResumeRequestParams struct {
ModelId string
ConversationId string
McpTools []McpToolDef
}
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
// Used to resume a conversation by conversation_id without re-sending full history.
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(p.McpTools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// ResumeAction
ra := newMsg("ResumeAction")
setMsg(ra, "request_context", rc)
// ConversationAction with resume_action
ca := newMsg("ConversationAction")
setMsg(ca, "resume_action", ra)
// ModelDetails
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// AgentRunRequest — no conversation_state needed for resume
arr := newMsg("AgentRunRequest")
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools at top level
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// --- KV response encoders ---
// Mirrors handleKvMessage() in cursor-fetch.ts
// EncodeKvGetBlobResult responds to a getBlobArgs request.
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
result := newMsg("GetBlobResult")
if blobData != nil {
setBytes(result, "blob_data", blobData)
}
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "get_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// EncodeKvSetBlobResult responds to a setBlobArgs request.
func EncodeKvSetBlobResult(kvId uint32) []byte {
result := newMsg("SetBlobResult")
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "set_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// --- Exec response encoders ---
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(tools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range tools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// RequestContextSuccess
rcs := newMsg("RequestContextSuccess")
setMsg(rcs, "request_context", rc)
// RequestContextResult (oneof success)
rcr := newMsg("RequestContextResult")
setMsg(rcr, "success", rcs)
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
}
// EncodeExecMcpResult responds with MCP tool result.
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
textContent := newMsg("McpTextContent")
setStr(textContent, "text", content)
contentItem := newMsg("McpToolResultContentItem")
setMsg(contentItem, "text", textContent)
success := newMsg("McpSuccess")
contentField := field(success, "content")
contentList := success.Mutable(contentField).List()
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
setBool(success, "is_error", isError)
result := newMsg("McpResult")
setMsg(result, "success", success)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// EncodeExecMcpError responds with MCP error.
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
mcpErr := newMsg("McpError")
setStr(mcpErr, "error", errMsg)
result := newMsg("McpResult")
setMsg(result, "error", mcpErr)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// --- Rejection encoders (mirror handleExecMessage rejections) ---
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("ReadRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("ReadResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
}
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("ShellResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
}
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("WriteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("WriteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
}
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("DeleteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("DeleteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
}
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("LsRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("LsResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
}
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
grepErr := newMsg("GrepError")
setStr(grepErr, "error", errMsg)
result := newMsg("GrepResult")
setMsg(result, "error", grepErr)
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
}
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
fetchErr := newMsg("FetchError")
setStr(fetchErr, "url", url)
setStr(fetchErr, "error", errMsg)
result := newMsg("FetchResult")
setMsg(result, "error", fetchErr)
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
}
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
result := newMsg("DiagnosticsResult")
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
}
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("BackgroundShellSpawnResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
}
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
wsErr := newMsg("WriteShellStdinError")
setStr(wsErr, "error", errMsg)
result := newMsg("WriteShellStdinResult")
setMsg(result, "error", wsErr)
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
}
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
// Mirrors sendExec() in cursor-fetch.ts.
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
ecm := newMsg("ExecClientMessage")
setUint32(ecm, "id", id)
// Force set exec_id even if empty - Cursor requires this field to be set
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
// Debug: check if field exists
fd := field(ecm, resultFieldName)
if fd == nil {
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
}
// Debug: log the actual field being set
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
acm := newMsg("AgentClientMessage")
setMsg(acm, "exec_client_message", ecm)
return marshal(acm)
}
func listFields(msg *dynamicpb.Message) []string {
var names []string
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
}
return names
}
// --- Utilities ---
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
if len(jsonData) == 0 {
return nil
}
var v interface{}
if err := json.Unmarshal(jsonData, &v); err != nil {
return jsonData // fallback to raw JSON if parsing fails
}
pbVal, err := structpb.NewValue(v)
if err != nil {
return jsonData // fallback
}
b, err := proto.Marshal(pbVal)
if err != nil {
return jsonData // fallback
}
return b
}
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
val := &structpb.Value{}
if err := proto.Unmarshal(data, val); err != nil {
return nil, err
}
return val.AsInterface(), nil
}
func sha256Sum(data []byte) []byte {
h := sha256.Sum256(data)
return h[:]
}
var idCounter uint64
func generateId() string {
idCounter++
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
return hex.EncodeToString(h[:16])
}

View File

@@ -0,0 +1,332 @@
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
package proto
// AgentClientMessage (msg 118) oneof "message"
const (
ACM_RunRequest = 1 // AgentRunRequest
ACM_ExecClientMessage = 2 // ExecClientMessage
ACM_KvClientMessage = 3 // KvClientMessage
ACM_ConversationAction = 4 // ConversationAction
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
ACM_InteractionResponse = 6 // InteractionResponse
ACM_ClientHeartbeat = 7 // ClientHeartbeat
)
// AgentServerMessage (msg 119) oneof "message"
const (
ASM_InteractionUpdate = 1 // InteractionUpdate
ASM_ExecServerMessage = 2 // ExecServerMessage
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
ASM_KvServerMessage = 4 // KvServerMessage
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
ASM_InteractionQuery = 7 // InteractionQuery
)
// AgentRunRequest (msg 91)
const (
ARR_ConversationState = 1 // ConversationStateStructure
ARR_Action = 2 // ConversationAction
ARR_ModelDetails = 3 // ModelDetails
ARR_McpTools = 4 // McpTools
ARR_ConversationId = 5 // string (optional)
)
// ConversationStateStructure (msg 83)
const (
CSS_RootPromptMessagesJson = 1 // repeated bytes
CSS_TurnsOld = 2 // repeated bytes (deprecated)
CSS_Todos = 3 // repeated bytes
CSS_PendingToolCalls = 4 // repeated string
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
CSS_PreviousWorkspaceUris = 9 // repeated string
CSS_SelfSummaryCount = 17 // uint32
CSS_ReadPaths = 18 // repeated string
)
// ConversationAction (msg 54) oneof "action"
const (
CA_UserMessageAction = 1 // UserMessageAction
)
// UserMessageAction (msg 55)
const (
UMA_UserMessage = 1 // UserMessage
)
// UserMessage (msg 63)
const (
UM_Text = 1 // string
UM_MessageId = 2 // string
UM_SelectedContext = 3 // SelectedContext (optional)
)
// SelectedContext
const (
SC_SelectedImages = 1 // repeated SelectedImage
)
// SelectedImage
const (
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
SI_Uuid = 2 // string
SI_Path = 3 // string
SI_MimeType = 7 // string
SI_Data = 8 // bytes (oneof dataOrBlobId)
)
// ModelDetails (msg 88)
const (
MD_ModelId = 1 // string
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
MD_DisplayModelId = 3 // string
MD_DisplayName = 4 // string
)
// McpTools (msg 307)
const (
MT_McpTools = 1 // repeated McpToolDefinition
)
// McpToolDefinition (msg 306)
const (
MTD_Name = 1 // string
MTD_Description = 2 // string
MTD_InputSchema = 3 // bytes
MTD_ProviderIdentifier = 4 // string
MTD_ToolName = 5 // string
)
// ConversationTurnStructure (msg 70) oneof "turn"
const (
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
)
// AgentConversationTurnStructure (msg 72)
const (
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
)
// ConversationStep (msg 53) oneof "message"
const (
CS_AssistantMessage = 1 // AssistantMessage
)
// AssistantMessage
const (
AM_Text = 1 // string
)
// --- Server-side message fields ---
// InteractionUpdate oneof "message"
const (
IU_TextDelta = 1 // TextDeltaUpdate
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
)
// TextDeltaUpdate (msg 92)
const (
TDU_Text = 1 // string
)
// ThinkingDeltaUpdate (msg 97)
const (
TKD_Text = 1 // string
)
// KvServerMessage (msg 271)
const (
KSM_Id = 1 // uint32
KSM_GetBlobArgs = 2 // GetBlobArgs
KSM_SetBlobArgs = 3 // SetBlobArgs
)
// GetBlobArgs (msg 267)
const (
GBA_BlobId = 1 // bytes
)
// SetBlobArgs (msg 269)
const (
SBA_BlobId = 1 // bytes
SBA_BlobData = 2 // bytes
)
// KvClientMessage (msg 272)
const (
KCM_Id = 1 // uint32
KCM_GetBlobResult = 2 // GetBlobResult
KCM_SetBlobResult = 3 // SetBlobResult
)
// GetBlobResult (msg 268)
const (
GBR_BlobData = 1 // bytes (optional)
)
// ExecServerMessage
const (
ESM_Id = 1 // uint32
ESM_ExecId = 15 // string
// oneof message:
ESM_ShellArgs = 2 // ShellArgs
ESM_WriteArgs = 3 // WriteArgs
ESM_DeleteArgs = 4 // DeleteArgs
ESM_GrepArgs = 5 // GrepArgs
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
ESM_LsArgs = 8 // LsArgs
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
ESM_RequestContextArgs = 10 // RequestContextArgs
ESM_McpArgs = 11 // McpArgs
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
ESM_FetchArgs = 20 // FetchArgs
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
)
// ExecClientMessage
const (
ECM_Id = 1 // uint32
ECM_ExecId = 15 // string
// oneof message (mirrors server fields):
ECM_ShellResult = 2
ECM_WriteResult = 3
ECM_DeleteResult = 4
ECM_GrepResult = 5
ECM_ReadResult = 7
ECM_LsResult = 8
ECM_DiagnosticsResult = 9
ECM_RequestContextResult = 10
ECM_McpResult = 11
ECM_ShellStream = 14
ECM_BackgroundShellSpawnRes = 16
ECM_FetchResult = 20
ECM_WriteShellStdinResult = 23
)
// McpArgs
const (
MCA_Name = 1 // string
MCA_Args = 2 // map<string, bytes>
MCA_ToolCallId = 3 // string
MCA_ProviderIdentifier = 4 // string
MCA_ToolName = 5 // string
)
// RequestContextResult oneof "result"
const (
RCR_Success = 1 // RequestContextSuccess
RCR_Error = 2 // RequestContextError
)
// RequestContextSuccess (msg 337)
const (
RCS_RequestContext = 1 // RequestContext
)
// RequestContext
const (
RC_Rules = 2 // repeated CursorRule
RC_Tools = 7 // repeated McpToolDefinition
)
// McpResult oneof "result"
const (
MCR_Success = 1 // McpSuccess
MCR_Error = 2 // McpError
MCR_Rejected = 3 // McpRejected
)
// McpSuccess (msg 290)
const (
MCS_Content = 1 // repeated McpToolResultContentItem
MCS_IsError = 2 // bool
)
// McpToolResultContentItem oneof "content"
const (
MTRCI_Text = 1 // McpTextContent
)
// McpTextContent (msg 287)
const (
MTC_Text = 1 // string
)
// McpError (msg 291)
const (
MCE_Error = 1 // string
)
// --- Rejection messages ---
// ReadRejected: path=1, reason=2
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
// WriteRejected: path=1, reason=2
// DeleteRejected: path=1, reason=2
// LsRejected: path=1, reason=2
// GrepError: error=1
// FetchError: url=1, error=2
// WriteShellStdinError: error=1
// ReadResult oneof: success=1, error=2, rejected=3
// ShellResult oneof: success=1 (+ various), rejected=?
// The TS code uses specific result field numbers from the oneof:
const (
RR_Rejected = 3 // ReadResult.rejected
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
WR_Rejected = 5 // WriteResult.rejected
DR_Rejected = 3 // DeleteResult.rejected
LR_Rejected = 3 // LsResult.rejected
GR_Error = 2 // GrepResult.error
FR_Error = 2 // FetchResult.error
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
WSSR_Error = 2 // WriteShellStdinResult.error
)
// --- Rejection struct fields ---
const (
REJ_Path = 1
REJ_Reason = 2
SREJ_Command = 1
SREJ_WorkingDir = 2
SREJ_Reason = 3
SREJ_IsReadonly = 4
GERR_Error = 1
FERR_Url = 1
FERR_Error = 2
)
// ReadArgs
const (
RA_Path = 1 // string
)
// WriteArgs
const (
WA_Path = 1 // string
)
// DeleteArgs
const (
DA_Path = 1 // string
)
// LsArgs
const (
LA_Path = 1 // string
)
// ShellArgs
const (
SHA_Command = 1 // string
SHA_WorkingDirectory = 2 // string
)
// FetchArgs
const (
FA_Url = 1 // string
)

View File

@@ -0,0 +1,313 @@
package proto
import (
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
const (
defaultInitialWindowSize = 65535 // HTTP/2 default
maxFramePayload = 16384 // HTTP/2 default max frame size
)
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
type H2Stream struct {
framer *http2.Framer
conn net.Conn
streamID uint32
mu sync.Mutex
id string // unique identifier for debugging
frameNum int64 // sequential frame counter for debugging
dataCh chan []byte
doneCh chan struct{}
err error
// Send-side flow control
sendWindow int32 // available bytes we can send on this stream
connWindow int32 // available bytes on the connection level
windowCond *sync.Cond // signaled when window is updated
windowMu sync.Mutex // protects sendWindow, connWindow
}
// ID returns the unique identifier for this stream (for logging).
func (s *H2Stream) ID() string { return s.id }
// FrameNum returns the current frame number for debugging.
func (s *H2Stream) FrameNum() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.frameNum
}
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
NextProtos: []string{"h2"},
})
if err != nil {
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
}
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
tlsConn.Close()
return nil, fmt.Errorf("h2: server did not negotiate h2")
}
framer := http2.NewFramer(tlsConn, tlsConn)
// Client connection preface
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: preface write failed: %w", err)
}
// Send initial SETTINGS (tell server how much WE can receive)
if err := framer.WriteSettings(
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: settings write failed: %w", err)
}
// Connection-level window update (for receiving)
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: window update failed: %w", err)
}
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
// Track server's initial window size (how much WE can send)
serverInitialWindowSize := int32(defaultInitialWindowSize)
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
for i := 0; i < 10; i++ {
f, err := framer.ReadFrame()
if err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
}
switch sf := f.(type) {
case *http2.SettingsFrame:
if !sf.IsAck() {
sf.ForeachSetting(func(s http2.Setting) error {
if s.ID == http2.SettingInitialWindowSize {
serverInitialWindowSize = int32(s.Val)
log.Debugf("h2: server initial window size: %d", s.Val)
}
return nil
})
framer.WriteSettingsAck()
} else {
goto handshakeDone
}
case *http2.WindowUpdateFrame:
if sf.StreamID == 0 {
connWindowSize += int32(sf.Increment)
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
}
default:
// unexpected but continue
}
}
handshakeDone:
// Build HEADERS
streamID := uint32(1)
var hdrBuf []byte
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
if p, ok := headers[":path"]; ok {
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
}
for k, v := range headers {
if len(k) > 0 && k[0] == ':' {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
if err := framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: hdrBuf,
EndStream: false,
EndHeaders: true,
}); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: headers write failed: %w", err)
}
s := &H2Stream{
framer: framer,
conn: tlsConn,
streamID: streamID,
dataCh: make(chan []byte, 256),
doneCh: make(chan struct{}),
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
frameNum: 0,
sendWindow: serverInitialWindowSize,
connWindow: connWindowSize,
}
s.windowCond = sync.NewCond(&s.windowMu)
go s.readLoop()
return s, nil
}
// Write sends a DATA frame on the stream, respecting flow control.
func (s *H2Stream) Write(data []byte) error {
for len(data) > 0 {
chunk := data
if len(chunk) > maxFramePayload {
chunk = data[:maxFramePayload]
}
// Wait for flow control window
s.windowMu.Lock()
for s.sendWindow <= 0 || s.connWindow <= 0 {
s.windowCond.Wait()
}
// Limit chunk to available window
allowed := int(s.sendWindow)
if int(s.connWindow) < allowed {
allowed = int(s.connWindow)
}
if len(chunk) > allowed {
chunk = chunk[:allowed]
}
s.sendWindow -= int32(len(chunk))
s.connWindow -= int32(len(chunk))
s.windowMu.Unlock()
s.mu.Lock()
err := s.framer.WriteData(s.streamID, false, chunk)
s.mu.Unlock()
if err != nil {
return err
}
data = data[len(chunk):]
}
return nil
}
// Data returns the channel of received data chunks.
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
// Done returns a channel closed when the stream ends.
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
// Err returns the error (if any) that caused the stream to close.
// Returns nil for a clean shutdown (EOF / StreamEnded).
func (s *H2Stream) Err() error { return s.err }
// Close tears down the connection.
func (s *H2Stream) Close() {
s.conn.Close()
// Unblock any writers waiting on flow control
s.windowCond.Broadcast()
}
func (s *H2Stream) readLoop() {
defer close(s.doneCh)
defer close(s.dataCh)
for {
f, err := s.framer.ReadFrame()
if err != nil {
if err != io.EOF {
s.err = err
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
}
return
}
// Increment frame counter
s.mu.Lock()
s.frameNum++
s.mu.Unlock()
switch frame := f.(type) {
case *http2.DataFrame:
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
cp := make([]byte, len(frame.Data()))
copy(cp, frame.Data())
s.dataCh <- cp
// Flow control: send WINDOW_UPDATE for received data
s.mu.Lock()
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
s.mu.Unlock()
}
if frame.StreamEnded() {
return
}
case *http2.HeadersFrame:
if frame.StreamEnded() {
return
}
case *http2.RSTStreamFrame:
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
return
case *http2.GoAwayFrame:
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
return
case *http2.PingFrame:
if !frame.IsAck() {
s.mu.Lock()
s.framer.WritePing(true, frame.Data)
s.mu.Unlock()
}
case *http2.SettingsFrame:
if !frame.IsAck() {
// Check for window size changes
frame.ForeachSetting(func(setting http2.Setting) error {
if setting.ID == http2.SettingInitialWindowSize {
s.windowMu.Lock()
delta := int32(setting.Val) - s.sendWindow
s.sendWindow += delta
s.windowMu.Unlock()
s.windowCond.Broadcast()
}
return nil
})
s.mu.Lock()
s.framer.WriteSettingsAck()
s.mu.Unlock()
}
case *http2.WindowUpdateFrame:
// Update send-side flow control window
s.windowMu.Lock()
if frame.StreamID == 0 {
s.connWindow += int32(frame.Increment)
} else if frame.StreamID == s.streamID {
s.sendWindow += int32(frame.Increment)
}
s.windowMu.Unlock()
s.windowCond.Broadcast()
}
}
}
type sliceWriter struct{ buf *[]byte }
func (w *sliceWriter) Write(p []byte) (int, error) {
*w.buf = append(*w.buf, p...)
return len(p), nil
}

View File

@@ -5,8 +5,7 @@ import (
)
// newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
// authenticators and a file-based token store.
//
// Returns:
// - *sdkAuth.Manager: A configured authentication manager instance
@@ -24,6 +23,8 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
sdkAuth.NewCodeBuddyAuthenticator(),
sdkAuth.NewCursorAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,43 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
// It initiates the OAuth authentication, displays the user code for the user to enter
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
//
// Parameters:
// - cfg: The application configuration containing proxy and auth directory settings
// - options: Login options including browser behavior settings
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
}
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
if err != nil {
log.Errorf("CodeBuddy authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("CodeBuddy authentication successful!")
}

View File

@@ -0,0 +1,37 @@
package cmd
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
}
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
if err != nil {
log.Errorf("Cursor authentication failed: %v", err)
return
}
if savedPath != "" {
log.Infof("Authentication saved to %s", savedPath)
}
if record != nil && record.Label != "" {
log.Infof("Authenticated as %s", record.Label)
}
log.Info("Cursor authentication successful!")
}

View File

@@ -13,6 +13,7 @@ import (
"strings"
"syscall"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
@@ -194,6 +195,9 @@ type RemoteManagement struct {
SecretKey string `yaml:"secret-key"`
// DisableControlPanel skips serving and syncing the bundled management UI when true.
DisableControlPanel bool `yaml:"disable-control-panel"`
// DisableAutoUpdatePanel disables automatic periodic background updates of the management panel asset from GitHub.
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
PanelGitHubRepository string `yaml:"panel-github-repository"`
@@ -574,6 +578,10 @@ type OpenAICompatibilityModel struct {
// Alias is the model name alias that clients will use to reference this model.
Alias string `yaml:"alias" json:"alias"`
// Thinking configures the thinking/reasoning capability for this model.
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
}
func (m OpenAICompatibilityModel) GetName() string { return m.Name }

View File

@@ -31,6 +31,7 @@ const (
httpUserAgent = "CLIProxyAPI-management-updater"
managementSyncMinInterval = 30 * time.Second
updateCheckInterval = 3 * time.Hour
maxAssetDownloadSize = 50 << 20 // 10 MB safety limit for management asset downloads
)
// ManagementFileName exposes the control panel asset filename.
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
log.Debug("management asset auto-updater skipped: control panel disabled")
return
}
if cfg.RemoteManagement.DisableAutoUpdatePanel {
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
return
}
configPath, _ := schedulerConfigPath.Load().(string)
staticDir := StaticDir(configPath)
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
}
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
return nil, nil
}
if err = atomicWriteFile(localPath, data); err != nil {
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
return false
}
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
data, err := io.ReadAll(resp.Body)
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
if err != nil {
return nil, "", fmt.Errorf("read download body: %w", err)
}
if int64(len(data)) > maxAssetDownloadSize {
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
}
sum := sha256.Sum256(data)
return data, hex.EncodeToString(sum[:]), nil

View File

@@ -88,6 +88,87 @@ func GetAntigravityModels() []*ModelInfo {
return cloneModelInfos(getModels().Antigravity)
}
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
// These models are served through the copilot.tencent.com API.
func GetCodeBuddyModels() []*ModelInfo {
now := int64(1748044800) // 2025-05-24
return []*ModelInfo{
{
ID: "glm-5.0",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-5.0",
Description: "GLM-5.0 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "glm-4.7",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-4.7",
Description: "GLM-4.7 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "minimax-m2.5",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "MiniMax M2.5",
Description: "MiniMax M2.5 via CodeBuddy",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "kimi-k2.5",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Kimi K2.5",
Description: "Kimi K2.5 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "deepseek-v3-2-volc",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "DeepSeek V3.2 (Volc)",
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "hunyuan-2.0-thinking",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Hunyuan 2.0 Thinking",
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{ZeroAllowed: true},
SupportedEndpoints: []string{"/chat/completions"},
},
}
}
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
@@ -148,11 +229,27 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetAmazonQModels()
case "antigravity":
return GetAntigravityModels()
case "codebuddy":
return GetCodeBuddyModels()
case "cursor":
return GetCursorModels()
default:
return nil
}
}
// GetCursorModels returns the fallback Cursor model definitions.
func GetCursorModels() []*ModelInfo {
return []*ModelInfo{
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
@@ -176,6 +273,8 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetKiroModels(),
GetKiloModels(),
GetAmazonQModels(),
GetCodeBuddyModels(),
GetCursorModels(),
}
for _, models := range allModels {
for _, m := range models {
@@ -365,6 +464,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.4",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.4",
Description: "OpenAI GPT-5.4 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
Object: "model",

View File

@@ -73,16 +73,16 @@ type availableModelsCacheEntry struct {
// Values are interpreted in provider-native token units.
type ThinkingSupport struct {
// Min is the minimum allowed thinking budget (inclusive).
Min int `json:"min,omitempty"`
Min int `json:"min,omitempty" yaml:"min,omitempty"`
// Max is the maximum allowed thinking budget (inclusive).
Max int `json:"max,omitempty"`
Max int `json:"max,omitempty" yaml:"max,omitempty"`
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
ZeroAllowed bool `json:"zero_allowed,omitempty"`
ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
// When set, the model uses level-based reasoning instead of token budgets.
Levels []string `json:"levels,omitempty"`
Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
}
// ModelRegistration tracks a model's availability

View File

@@ -0,0 +1,343 @@
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
)
const (
codeBuddyChatPath = "/v2/chat/completions"
codeBuddyAuthType = "codebuddy"
)
// CodeBuddyExecutor handles requests to the CodeBuddy API.
type CodeBuddyExecutor struct {
cfg *config.Config
}
// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
return &CodeBuddyExecutor{cfg: cfg}
}
// Identifier returns the unique identifier for this executor.
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
// codeBuddyCredentials extracts the access token and domain from auth metadata.
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
if auth == nil {
return "", "", ""
}
accessToken = metaStringValue(auth.Metadata, "access_token")
userID = metaStringValue(auth.Metadata, "user_id")
domain = metaStringValue(auth.Metadata, "domain")
if domain == "" {
domain = codebuddy.DefaultDomain
}
return
}
// PrepareRequest prepares the HTTP request before execution.
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return fmt.Errorf("codebuddy: missing access token")
}
e.applyHeaders(req, accessToken, userID, domain)
return nil
}
// HttpRequest executes a raw HTTP request.
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("codebuddy executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request.
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return resp, fmt.Errorf("codebuddy: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := codebuddy.BaseURL + codeBuddyChatPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codebuddy executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request.
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("codebuddy: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := codebuddy.BaseURL + codeBuddyChatPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
httpResp.Body.Close()
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codebuddy executor: close stream body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, maxScannerBufferSize)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// Refresh exchanges the CodeBuddy refresh token for a new access token.
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("codebuddy: missing auth")
}
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
if refreshToken == "" {
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
return auth, nil
}
accessToken, userID, domain := codeBuddyCredentials(auth)
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
if err != nil {
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
}
updated := auth.Clone()
updated.Metadata["access_token"] = storage.AccessToken
if storage.RefreshToken != "" {
updated.Metadata["refresh_token"] = storage.RefreshToken
}
updated.Metadata["expires_in"] = storage.ExpiresIn
updated.Metadata["domain"] = storage.Domain
if storage.UserID != "" {
updated.Metadata["user_id"] = storage.UserID
}
now := time.Now()
updated.UpdatedAt = now
updated.LastRefreshedAt = now
return updated, nil
}
// CountTokens is not supported for CodeBuddy.
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
}
// applyHeaders sets required headers for CodeBuddy API requests.
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", codebuddy.UserAgent)
req.Header.Set("X-User-Id", userID)
req.Header.Set("X-Domain", domain)
req.Header.Set("X-Product", "SaaS")
req.Header.Set("X-IDE-Type", "CLI")
req.Header.Set("X-IDE-Name", "CLI")
req.Header.Set("X-IDE-Version", "2.63.2")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
}

View File

@@ -28,8 +28,8 @@ import (
)
const (
codexClientVersion = "0.101.0"
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexOriginator = "codex_cli_rs"
)
var dataTag = []byte("data:")
@@ -645,8 +645,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
ginHeaders = ginCtx.Request.Header
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
@@ -663,8 +665,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
isAPIKey = true
}
}
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
r.Header.Set("Originator", originator)
} else if !isAPIKey {
r.Header.Set("Originator", codexOriginator)
}
if !isAPIKey {
r.Header.Set("Originator", "codex_cli_rs")
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
r.Header.Set("Chatgpt-Account-Id", accountID)

View File

@@ -814,9 +814,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil {
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
@@ -834,8 +835,12 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
isAPIKey = true
}
}
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
headers.Set("Originator", codexOriginator)
}
if !isAPIKey {
headers.Set("Originator", "codex_cli_rs")
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" {

View File

@@ -41,9 +41,46 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
if got := headers.Get("X-Codex-Turn-Metadata"); got != "" {
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
}
if got := headers.Get("X-Client-Request-Id"); got != "" {
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
if got := headers.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
if got := headers.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
}
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
@@ -177,6 +214,57 @@ func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
}
}
func TestApplyCodexHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
req = req.WithContext(contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
}))
applyCodexHeaders(req, auth, "oauth-token", true, nil)
if got := req.Header.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := req.Header.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
}
if got := req.Header.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
}
func TestApplyCodexHeadersDoesNotInjectClientOnlyHeadersByDefault(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
applyCodexHeaders(req, nil, "oauth-token", true, nil)
if got := req.Header.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
}
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != "" {
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
}
if got := req.Header.Get("X-Client-Request-Id"); got != "" {
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
}
}
func contextWithGinHeaders(headers map[string]string) context.Context {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()

File diff suppressed because it is too large Load Diff

View File

@@ -577,9 +577,33 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model
return true
}
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
if info := registry.GetGlobalRegistry().GetModelInfo(baseModel, githubCopilotAuthType); info != nil {
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
}
if info := lookupGitHubCopilotStaticModelInfo(baseModel); info != nil {
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
}
return strings.Contains(baseModel, "codex")
}
func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
for _, info := range registry.GetStaticModelDefinitionsByChannel(githubCopilotAuthType) {
if info != nil && strings.EqualFold(info.ID, model) {
return info
}
}
return nil
}
func containsEndpoint(endpoints []string, endpoint string) bool {
for _, item := range endpoints {
if item == endpoint {
return true
}
}
return false
}
// flattenAssistantContent converts assistant message content from array format
// to a joined string. GitHub Copilot requires assistant content as a string;
// sending it as an array causes Claude models to re-answer all previous prompts.

View File

@@ -5,6 +5,7 @@ import (
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -70,6 +71,29 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
}
}
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
t.Parallel()
reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client"
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
}})
defer reg.UnregisterClient(clientID)
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {

View File

@@ -104,59 +104,59 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig
// log.Debugf("Using cached signature for thinking block")
}
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" {
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
if len(arrayClientSignatures) == 2 {
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
clientSignature = arrayClientSignatures[1]
}
signature := ""
if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig
// log.Debugf("Using cached signature for thinking block")
}
}
if cache.HasValidSignature(modelName, clientSignature) {
signature = clientSignature
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" {
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
if len(arrayClientSignatures) == 2 {
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
clientSignature = arrayClientSignatures[1]
}
}
}
if cache.HasValidSignature(modelName, clientSignature) {
signature = clientSignature
}
// log.Debugf("Using client-provided signature for thinking block")
}
// log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(modelName, signature) {
currentMessageThinkingSignature = signature
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(modelName, signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(modelName, signature)
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(modelName, signature)
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
enableThoughtTranslate = false
continue
}
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
enableThoughtTranslate = false
continue
}
// Valid signature, send as thought block
// Always include "text" field — Google Antigravity API requires it
// even for redacted thinking where the text is empty.
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
if signature != "" {
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
}
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
// Valid signature, send as thought block
// Always include "text" field — Google Antigravity API requires it
// even for redacted thinking where the text is empty.
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
if signature != "" {
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
}
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
// Skip empty text parts to avoid Gemini API error:

View File

@@ -165,29 +165,22 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
// Process messages and transform them to Claude Code format
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
messageIndex := 0
systemMessageIndex := -1
messages.ForEach(func(_, message gjson.Result) bool {
role := message.Get("role").String()
contentResult := message.Get("content")
switch role {
case "system":
if systemMessageIndex == -1 {
systemMsg := []byte(`{"role":"user","content":[]}`)
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMsg)
systemMessageIndex = messageIndex
messageIndex++
}
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
}
return true
})
@@ -269,6 +262,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
}
return true
})
// Preserve a minimal conversational turn for system-only inputs.
// Claude payloads with top-level system instructions but no messages are risky for downstream validation.
if messageIndex == 0 {
system := gjson.GetBytes(out, "system")
if system.Exists() && system.IsArray() && len(system.Array()) > 0 {
fallbackMsg := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
out, _ = sjson.SetRawBytes(out, "messages.-1", fallbackMsg)
}
}
}
// Tools mapping: OpenAI tools -> Claude Code tools

View File

@@ -135,3 +135,111 @@ func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
t.Fatalf("Unexpected image URL: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_SystemRoleBecomesTopLevelSystem(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system")
if !system.IsArray() {
t.Fatalf("Expected top-level system array, got %s", system.Raw)
}
if len(system.Array()) != 1 {
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system.Array()), system.Raw)
}
if got := system.Get("0.type").String(); got != "text" {
t.Fatalf("Expected system block type %q, got %q", "text", got)
}
if got := system.Get("0.text").String(); got != "You are a helpful assistant." {
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
t.Fatalf("Expected user text %q, got %q", "Hello", got)
}
}
func TestConvertOpenAIRequestToClaude_MultipleSystemMessagesMergedIntoTopLevelSystem(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "Rule 1"},
{"role": "system", "content": [{"type": "text", "text": "Rule 2"}]},
{"role": "user", "content": "Hello"}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system").Array()
if len(system) != 2 {
t.Fatalf("Expected 2 system blocks, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
}
if got := system[0].Get("text").String(); got != "Rule 1" {
t.Fatalf("Expected first system text %q, got %q", "Rule 1", got)
}
if got := system[1].Get("text").String(); got != "Rule 2" {
t.Fatalf("Expected second system text %q, got %q", "Rule 2", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
t.Fatalf("Expected user text %q, got %q", "Hello", got)
}
}
func TestConvertOpenAIRequestToClaude_SystemOnlyInputKeepsFallbackUserMessage(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "You are a helpful assistant."}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system").Array()
if len(system) != 1 {
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
}
if got := system[0].Get("text").String(); got != "You are a helpful assistant." {
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 fallback message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected fallback message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.type").String(); got != "text" {
t.Fatalf("Expected fallback content type %q, got %q", "text", got)
}
if got := messages[0].Get("content.0.text").String(); got != "" {
t.Fatalf("Expected fallback text %q, got %q", "", got)
}
}

View File

@@ -60,7 +60,7 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
rawJSON = bytes.TrimSpace(rawJSON[5:])
// Initialize the OpenAI SSE template.
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{},"finish_reason":null,"native_finish_reason":null}]}`)
rootResult := gjson.ParseBytes(rawJSON)

View File

@@ -45,3 +45,48 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}
}
func TestConvertCodexResponseToOpenAI_ToolCallChunkOmitsNullContentFields(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), &param)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
}
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
}
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls").Exists() {
t.Fatalf("expected tool_calls to exist, got %s", string(out[0]))
}
}
func TestConvertCodexResponseToOpenAI_ToolCallArgumentsDeltaOmitsNullContentFields(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), &param)
if len(out) != 1 {
t.Fatalf("expected tool call announcement chunk, got %d", len(out))
}
out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.function_call_arguments.delta","delta":"{\"query\":\"OpenAI\"}"}`), &param)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
}
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
}
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls.0.function.arguments").Exists() {
t.Fatalf("expected tool call arguments delta to exist, got %s", string(out[0]))
}
}

View File

@@ -57,6 +57,12 @@ func GetProviderName(modelName string) []string {
return providers
}
// Fallback: if cursor provider has registered models, route unknown models to it.
// Cursor acts as a universal proxy supporting multiple model families (Claude, GPT, Gemini, etc.).
if models := registry.GetGlobalRegistry().GetAvailableModelsByProvider("cursor"); len(models) > 0 {
return []string{"cursor"}
}
return providers
}

View File

@@ -256,6 +256,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
}
if oldCfg.RemoteManagement.DisableAutoUpdatePanel != newCfg.RemoteManagement.DisableAutoUpdatePanel {
changes = append(changes, fmt.Sprintf("remote-management.disable-auto-update-panel: %t -> %t", oldCfg.RemoteManagement.DisableAutoUpdatePanel, newCfg.RemoteManagement.DisableAutoUpdatePanel))
}
oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
if oldPanelRepo != newPanelRepo {

View File

@@ -20,10 +20,11 @@ func TestBuildConfigChangeDetails(t *testing.T) {
RestrictManagementToLocalhost: false,
},
RemoteManagement: config.RemoteManagement{
AllowRemote: false,
SecretKey: "old",
DisableControlPanel: false,
PanelGitHubRepository: "repo-old",
AllowRemote: false,
SecretKey: "old",
DisableControlPanel: false,
DisableAutoUpdatePanel: false,
PanelGitHubRepository: "repo-old",
},
OAuthExcludedModels: map[string][]string{
"providerA": {"m1"},
@@ -54,10 +55,11 @@ func TestBuildConfigChangeDetails(t *testing.T) {
},
},
RemoteManagement: config.RemoteManagement{
AllowRemote: true,
SecretKey: "new",
DisableControlPanel: true,
PanelGitHubRepository: "repo-new",
AllowRemote: true,
SecretKey: "new",
DisableControlPanel: true,
DisableAutoUpdatePanel: true,
PanelGitHubRepository: "repo-new",
},
OAuthExcludedModels: map[string][]string{
"providerA": {"m1", "m2"},
@@ -88,6 +90,7 @@ func TestBuildConfigChangeDetails(t *testing.T) {
expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream")
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)")
expectContains(t, details, "remote-management.allow-remote: false -> true")
expectContains(t, details, "remote-management.disable-auto-update-panel: false -> true")
expectContains(t, details, "remote-management.secret-key: updated")
expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)")
expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)")
@@ -265,9 +268,10 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
},
RemoteManagement: config.RemoteManagement{
DisableControlPanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
DisableControlPanel: true,
DisableAutoUpdatePanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: true,
@@ -299,6 +303,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true")
expectContains(t, details, "ampcode.upstream-api-key: removed")
expectContains(t, details, "remote-management.disable-control-panel: false -> true")
expectContains(t, details, "remote-management.disable-auto-update-panel: false -> true")
expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo")
expectContains(t, details, "remote-management.secret-key: deleted")
}
@@ -336,10 +341,11 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
ForceModelMappings: false,
},
RemoteManagement: config.RemoteManagement{
AllowRemote: false,
DisableControlPanel: false,
PanelGitHubRepository: "old/repo",
SecretKey: "old",
AllowRemote: false,
DisableControlPanel: false,
DisableAutoUpdatePanel: false,
PanelGitHubRepository: "old/repo",
SecretKey: "old",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: false,
@@ -389,10 +395,11 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
ForceModelMappings: true,
},
RemoteManagement: config.RemoteManagement{
AllowRemote: true,
DisableControlPanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
AllowRemote: true,
DisableControlPanel: true,
DisableAutoUpdatePanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: true,
@@ -460,6 +467,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)")
expectContains(t, changes, "remote-management.allow-remote: false -> true")
expectContains(t, changes, "remote-management.disable-control-panel: false -> true")
expectContains(t, changes, "remote-management.disable-auto-update-panel: false -> true")
expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo")
expectContains(t, changes, "remote-management.secret-key: deleted")
expectContains(t, changes, "openai-compatibility:")

95
sdk/auth/codebuddy.go Normal file
View File

@@ -0,0 +1,95 @@
package auth
import (
"context"
"fmt"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// CodeBuddyAuthenticator implements the browser OAuth polling flow for CodeBuddy.
type CodeBuddyAuthenticator struct{}
// NewCodeBuddyAuthenticator constructs a new CodeBuddy authenticator.
func NewCodeBuddyAuthenticator() Authenticator {
return &CodeBuddyAuthenticator{}
}
// Provider returns the provider key for codebuddy.
func (CodeBuddyAuthenticator) Provider() string {
return "codebuddy"
}
// codeBuddyRefreshLead is the duration before token expiry when a refresh should be attempted.
var codeBuddyRefreshLead = 24 * time.Hour
// RefreshLead returns how soon before expiry a refresh should be attempted.
// CodeBuddy tokens have a long validity period, so we refresh 24 hours before expiry.
func (CodeBuddyAuthenticator) RefreshLead() *time.Duration {
return &codeBuddyRefreshLead
}
// Login initiates the browser OAuth flow for CodeBuddy.
func (a CodeBuddyAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("codebuddy: configuration is required")
}
if opts == nil {
opts = &LoginOptions{}
}
if ctx == nil {
ctx = context.Background()
}
authSvc := codebuddy.NewCodeBuddyAuth(cfg)
authState, err := authSvc.FetchAuthState(ctx)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to fetch auth state: %w", err)
}
fmt.Printf("\nPlease open the following URL in your browser to login:\n\n %s\n\n", authState.AuthURL)
fmt.Println("Waiting for authorization...")
if !opts.NoBrowser {
if browser.IsAvailable() {
if errOpen := browser.OpenURL(authState.AuthURL); errOpen != nil {
log.Debugf("codebuddy: failed to open browser: %v", errOpen)
}
}
}
storage, err := authSvc.PollForToken(ctx, authState.State)
if err != nil {
return nil, fmt.Errorf("codebuddy: %s: %w", codebuddy.GetUserFriendlyMessage(err), err)
}
fmt.Printf("\nSuccessfully logged in! (User ID: %s)\n", storage.UserID)
authID := fmt.Sprintf("codebuddy-%s.json", storage.UserID)
label := storage.UserID
if label == "" {
label = "codebuddy-user"
}
return &coreauth.Auth{
ID: authID,
Provider: a.Provider(),
FileName: authID,
Label: label,
Storage: storage,
Metadata: map[string]any{
"access_token": storage.AccessToken,
"refresh_token": storage.RefreshToken,
"user_id": storage.UserID,
"domain": storage.Domain,
"expires_in": storage.ExpiresIn,
},
}, nil
}

98
sdk/auth/cursor.go Normal file
View File

@@ -0,0 +1,98 @@
package auth
import (
"context"
"fmt"
"time"
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// CursorAuthenticator implements OAuth PKCE login for Cursor.
type CursorAuthenticator struct{}
// NewCursorAuthenticator constructs a new Cursor authenticator.
func NewCursorAuthenticator() Authenticator {
return &CursorAuthenticator{}
}
// Provider returns the provider key for cursor.
func (CursorAuthenticator) Provider() string {
return "cursor"
}
// RefreshLead returns the time before expiry when a refresh should be attempted.
func (CursorAuthenticator) RefreshLead() *time.Duration {
d := 10 * time.Minute
return &d
}
// Login initiates the Cursor PKCE authentication flow.
func (a CursorAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cursor auth: configuration is required")
}
if opts == nil {
opts = &LoginOptions{}
}
// Generate PKCE auth parameters
authParams, err := cursorauth.GenerateAuthParams()
if err != nil {
return nil, fmt.Errorf("cursor: failed to generate auth params: %w", err)
}
// Display the login URL
log.Info("Starting Cursor authentication...")
log.Infof("Please visit this URL to log in: %s", authParams.LoginURL)
// Try to open the browser automatically
if !opts.NoBrowser {
if browser.IsAvailable() {
if errOpen := browser.OpenURL(authParams.LoginURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
}
}
}
log.Info("Waiting for Cursor authorization...")
// Poll for the auth result
tokens, err := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
if err != nil {
return nil, fmt.Errorf("cursor: authentication failed: %w", err)
}
expiresAt := cursorauth.GetTokenExpiry(tokens.AccessToken)
// Auto-identify account from JWT sub claim
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
subHash := cursorauth.SubToShortHash(sub)
log.Info("Cursor authentication successful!")
metadata := map[string]any{
"type": "cursor",
"access_token": tokens.AccessToken,
"refresh_token": tokens.RefreshToken,
"expires_at": expiresAt.Format(time.RFC3339),
"timestamp": time.Now().UnixMilli(),
}
if sub != "" {
metadata["sub"] = sub
}
fileName := cursorauth.CredentialFileName("", subHash)
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: cursorauth.DisplayLabel("", subHash),
Metadata: metadata,
}, nil
}

View File

@@ -18,6 +18,8 @@ func init() {
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
registerRefreshLead("codebuddy", func() Authenticator { return NewCodeBuddyAuthenticator() })
registerRefreshLead("cursor", func() Authenticator { return NewCursorAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {

View File

@@ -545,6 +545,11 @@ func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
if modelKey == "" {
return true
}
// Cursor acts as a universal proxy supporting multiple model families.
// Allow any model to be routed to cursor auth.
if m.providerKey == "cursor" {
return true
}
if len(m.supportedModelSet) == 0 {
return false
}

View File

@@ -441,8 +441,12 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
case "kilo":
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
case "cursor":
s.coreManager.RegisterExecutor(executor.NewCursorExecutor(s.cfg))
case "github-copilot":
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
case "codebuddy":
s.coreManager.RegisterExecutor(executor.NewCodeBuddyExecutor(s.cfg))
case "gitlab":
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
default:
@@ -940,6 +944,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kimi":
models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded)
case "cursor":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
models = executor.FetchCursorModels(ctx, a, s.cfg)
models = applyExcludedModels(models, excluded)
case "github-copilot":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
@@ -954,6 +963,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "gitlab":
models = executor.GitLabModelsFromAuth(a)
models = applyExcludedModels(models, excluded)
case "codebuddy":
models = registry.GetCodeBuddyModels()
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {
@@ -1006,6 +1018,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if modelID == "" {
modelID = m.Name
}
thinking := m.Thinking
if thinking == nil {
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
}
ms = append(ms, &ModelInfo{
ID: modelID,
Object: "model",
@@ -1013,7 +1029,8 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
OwnedBy: compat.Name,
Type: "openai-compatibility",
DisplayName: modelID,
UserDefined: true,
UserDefined: false,
Thinking: thinking,
})
}
// Register and return

309
test_cursor.sh Executable file
View File

@@ -0,0 +1,309 @@
#!/bin/bash
# Test script for Cursor proxy integration
# Usage:
# ./test_cursor.sh login - Login to Cursor (opens browser)
# ./test_cursor.sh start - Build and start the server
# ./test_cursor.sh test - Run API tests against running server
# ./test_cursor.sh all - Login + Start + Test (full flow)
set -e
export PATH="/opt/homebrew/bin:$PATH"
export GOROOT="/opt/homebrew/Cellar/go/1.26.1/libexec"
PROJECT_DIR="/Volumes/Personal/cursor-cli-proxy/CLIProxyAPIPlus"
BINARY="$PROJECT_DIR/cliproxy-test"
API_KEY="quotio-local-D6ABC285-3085-44B4-B872-BD269888811F"
BASE_URL="http://127.0.0.1:8317"
CONFIG="$PROJECT_DIR/config-cursor-test.yaml"
PID_FILE="/tmp/cliproxy-test.pid"
# Colors
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m'
info() { echo -e "${GREEN}[INFO]${NC} $1"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
error() { echo -e "${RED}[ERROR]${NC} $1"; }
# --- Build ---
build() {
info "Building CLIProxyAPIPlus..."
cd "$PROJECT_DIR"
go build -o "$BINARY" ./cmd/server/
info "Build successful: $BINARY"
}
# --- Create test config ---
create_config() {
cat > "$CONFIG" << 'EOF'
host: '127.0.0.1'
port: 8317
auth-dir: '~/.cli-proxy-api'
api-keys:
- 'quotio-local-D6ABC285-3085-44B4-B872-BD269888811F'
debug: true
EOF
info "Test config created: $CONFIG"
}
# --- Login ---
do_login() {
build
create_config
info "Starting Cursor login (will open browser)..."
"$BINARY" --config "$CONFIG" --cursor-login
}
# --- Start server ---
start_server() {
# Kill any existing instance
stop_server 2>/dev/null || true
build
create_config
info "Starting server on port 8317..."
"$BINARY" --config "$CONFIG" &
SERVER_PID=$!
echo "$SERVER_PID" > "$PID_FILE"
info "Server started (PID: $SERVER_PID)"
# Wait for server to be ready
info "Waiting for server to be ready..."
for i in $(seq 1 15); do
if curl -s "$BASE_URL/v1/models" -H "Authorization: Bearer $API_KEY" > /dev/null 2>&1; then
info "Server is ready!"
return 0
fi
sleep 1
done
error "Server failed to start within 15 seconds"
return 1
}
# --- Stop server ---
stop_server() {
if [ -f "$PID_FILE" ]; then
PID=$(cat "$PID_FILE")
if kill -0 "$PID" 2>/dev/null; then
info "Stopping server (PID: $PID)..."
kill "$PID"
rm -f "$PID_FILE"
fi
fi
# Also kill any stale process on port 8317
lsof -ti:8317 2>/dev/null | xargs kill 2>/dev/null || true
}
# --- Test: List models ---
test_models() {
info "Testing GET /v1/models (looking for cursor models)..."
RESPONSE=$(curl -s "$BASE_URL/v1/models" \
-H "Authorization: Bearer $API_KEY")
CURSOR_MODELS=$(echo "$RESPONSE" | python3 -c "
import json, sys
try:
data = json.load(sys.stdin)
models = [m['id'] for m in data.get('data', []) if m.get('owned_by') == 'cursor' or m.get('type') == 'cursor']
if models:
print('\n'.join(models))
else:
print('NONE')
except:
print('ERROR')
" 2>/dev/null || echo "PARSE_ERROR")
if [ "$CURSOR_MODELS" = "NONE" ] || [ "$CURSOR_MODELS" = "ERROR" ] || [ "$CURSOR_MODELS" = "PARSE_ERROR" ]; then
warn "No cursor models found. Have you run '--cursor-login' first?"
echo " Response preview: $(echo "$RESPONSE" | head -c 200)"
return 1
else
info "Found cursor models:"
echo "$CURSOR_MODELS" | while read -r model; do
echo " - $model"
done
return 0
fi
}
# --- Test: Chat completion (streaming) ---
test_chat_stream() {
local model="${1:-cursor-small}"
info "Testing POST /v1/chat/completions (stream, model=$model)..."
RESPONSE=$(curl -s --max-time 30 "$BASE_URL/v1/chat/completions" \
-H "Authorization: Bearer $API_KEY" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"$model\",
\"messages\": [{\"role\": \"user\", \"content\": \"Say hello in exactly 3 words.\"}],
\"stream\": true
}" 2>&1)
# Check if we got SSE data
if echo "$RESPONSE" | grep -q "data:"; then
# Extract content from SSE chunks
CONTENT=$(echo "$RESPONSE" | grep "^data: " | grep -v "\[DONE\]" | while read -r line; do
echo "${line#data: }" | python3 -c "
import json, sys
try:
chunk = json.load(sys.stdin)
delta = chunk.get('choices', [{}])[0].get('delta', {})
content = delta.get('content', '')
if content:
sys.stdout.write(content)
except:
pass
" 2>/dev/null
done)
if [ -n "$CONTENT" ]; then
info "Stream response received:"
echo " Content: $CONTENT"
return 0
else
warn "Got SSE chunks but no content extracted"
echo " Raw (first 500 chars): $(echo "$RESPONSE" | head -c 500)"
return 1
fi
else
error "No SSE data received"
echo " Response: $(echo "$RESPONSE" | head -c 300)"
return 1
fi
}
# --- Test: Chat completion (non-streaming) ---
test_chat_nonstream() {
local model="${1:-cursor-small}"
info "Testing POST /v1/chat/completions (non-stream, model=$model)..."
RESPONSE=$(curl -s --max-time 30 "$BASE_URL/v1/chat/completions" \
-H "Authorization: Bearer $API_KEY" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"$model\",
\"messages\": [{\"role\": \"user\", \"content\": \"What is 2+2? Answer with just the number.\"}],
\"stream\": false
}" 2>&1)
CONTENT=$(echo "$RESPONSE" | python3 -c "
import json, sys
try:
data = json.load(sys.stdin)
content = data['choices'][0]['message']['content']
print(content)
except Exception as e:
print(f'ERROR: {e}')
" 2>/dev/null || echo "PARSE_ERROR")
if echo "$CONTENT" | grep -q "ERROR\|PARSE_ERROR"; then
error "Non-streaming request failed"
echo " Response: $(echo "$RESPONSE" | head -c 300)"
return 1
else
info "Non-stream response received:"
echo " Content: $CONTENT"
return 0
fi
}
# --- Run all tests ---
run_tests() {
local passed=0
local failed=0
echo ""
echo "========================================="
echo " Cursor Proxy Integration Tests"
echo "========================================="
echo ""
# Test 1: Models
if test_models; then
((passed++))
else
((failed++))
fi
echo ""
# Test 2: Streaming chat
if test_chat_stream "cursor-small"; then
((passed++))
else
((failed++))
fi
echo ""
# Test 3: Non-streaming chat
if test_chat_nonstream "cursor-small"; then
((passed++))
else
((failed++))
fi
echo ""
echo "========================================="
echo " Results: ${passed} passed, ${failed} failed"
echo "========================================="
[ "$failed" -eq 0 ]
}
# --- Cleanup ---
cleanup() {
stop_server
rm -f "$BINARY" "$CONFIG"
info "Cleaned up."
}
# --- Main ---
case "${1:-help}" in
login)
do_login
;;
start)
start_server
info "Server running. Use './test_cursor.sh test' to run tests."
info "Use './test_cursor.sh stop' to stop."
;;
stop)
stop_server
;;
test)
run_tests
;;
all)
info "=== Full flow: login -> start -> test ==="
echo ""
info "Step 1: Login to Cursor"
do_login
echo ""
info "Step 2: Start server"
start_server
echo ""
info "Step 3: Run tests"
sleep 2
run_tests
echo ""
info "Step 4: Cleanup"
stop_server
;;
clean)
cleanup
;;
*)
echo "Usage: $0 {login|start|stop|test|all|clean}"
echo ""
echo " login - Authenticate with Cursor (opens browser)"
echo " start - Build and start the proxy server"
echo " stop - Stop the running server"
echo " test - Run API tests against running server"
echo " all - Full flow: login + start + test"
echo " clean - Stop server and remove artifacts"
;;
esac