mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 01:06:39 +00:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8d1bc06ea | ||
|
|
d5930f4e44 | ||
|
|
9b7d7021af | ||
|
|
e41c22ef44 | ||
|
|
55271403fb | ||
|
|
36fba66619 | ||
|
|
b9b127a7ea | ||
|
|
2741e7b7b3 | ||
|
|
1767a56d4f | ||
|
|
779e6c2d2f | ||
|
|
73c831747b | ||
|
|
b8b89f34f4 | ||
|
|
1fa094dac6 | ||
|
|
e5d3541b5a | ||
|
|
79755e76ea | ||
|
|
35f158d526 | ||
|
|
6962e09dd9 | ||
|
|
4c4cbd44da | ||
|
|
26eca8b6ba | ||
|
|
62b17f40a1 | ||
|
|
511b8a992e | ||
|
|
0ab977c236 | ||
|
|
224f0de353 | ||
|
|
d54de441d3 | ||
|
|
7386a70724 | ||
|
|
1b7447b682 | ||
|
|
40dee4453a | ||
|
|
8902e1cccb | ||
|
|
de5fe71478 | ||
|
|
dcfbec2990 | ||
|
|
c95620f90e | ||
|
|
754f3bcbc3 | ||
|
|
36973d4a6f | ||
|
|
9613f0b3f9 | ||
|
|
274f29e26b | ||
|
|
c8e79c3787 | ||
|
|
8afef43887 | ||
|
|
c1083cbfc6 | ||
|
|
c89d19b300 | ||
|
|
19c52bcb60 | ||
|
|
cc32f5ff61 | ||
|
|
fbff68b9e0 | ||
|
|
7e1a543b79 | ||
|
|
74b862d8b8 | ||
|
|
5c817a9b42 | ||
|
|
5da0decef6 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
# Binaries
|
||||
cli-proxy-api
|
||||
cliproxy
|
||||
/server
|
||||
*.exe
|
||||
|
||||
|
||||
|
||||
12
README_JA.md
12
README_JA.md
@@ -34,6 +34,10 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
|
||||
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
|
||||
<td>LingtrueAPIのスポンサーシップに感謝します!LingtrueAPIはグローバルな大規模モデルAPIリレーサービスプラットフォームで、Claude Code、Codex、GeminiなどのトップモデルAPI呼び出しサービスを提供し、ユーザーが低コストかつ高い安定性で世界中のAI能力に接続できるよう支援しています。LingtrueAPIは本ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.lingtrue.com/register">こちらのリンク</a>から登録し、初回チャージ時にプロモーションコード「LingtrueAPI」を入力すると10%割引になります。</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -78,6 +82,14 @@ CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統
|
||||
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
|
||||
|
||||
特定のバックエンド系統のリクエスト/レスポンス形状が必要な場合は、統合された `/v1/...` エンドポイントよりも provider-specific のパスを優先してください。
|
||||
|
||||
- messages 系のバックエンドには `/api/provider/{provider}/v1/messages`
|
||||
- モデル単位の generate 系エンドポイントには `/api/provider/{provider}/v1beta/models/...`
|
||||
- chat-completions 系のバックエンドには `/api/provider/{provider}/v1/chat/completions`
|
||||
|
||||
これらのパスはプロトコル面の選択には役立ちますが、同じクライアント向けモデル名が複数バックエンドで再利用されている場合、それだけで推論実行系が一意に固定されるわけではありません。実際の推論ルーティングは、引き続きリクエスト内の model/alias 解決に従います。厳密にバックエンドを固定したい場合は、一意な alias や prefix を使うか、クライアント向けモデル名の重複自体を避けてください。
|
||||
|
||||
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||
|
||||
## SDKドキュメント
|
||||
|
||||
BIN
assets/lingtrue.png
Normal file
BIN
assets/lingtrue.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 129 KiB |
20
cmd/mcpdebug/main.go
Normal file
20
cmd/mcpdebug/main.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Encode MCP result with empty execId
|
||||
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
|
||||
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
|
||||
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
|
||||
|
||||
// Write to file for analysis
|
||||
os.WriteFile("mcp_result.bin", resultBytes)
|
||||
fmt.Println("Wrote mcp_result.bin")
|
||||
}
|
||||
32
cmd/protocheck/main.go
Normal file
32
cmd/protocheck/main.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||
|
||||
// Try different field names
|
||||
names := []string{
|
||||
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||
"shell_result", "shellResult",
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
fd := ecm.Descriptor().Fields().ByName(name)
|
||||
if fd != nil {
|
||||
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
|
||||
} else {
|
||||
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
// List all fields
|
||||
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||
f := ecm.Descriptor().Fields().Get(i)
|
||||
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,7 @@ func main() {
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var cursorLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
@@ -123,6 +124,7 @@ func main() {
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
@@ -544,6 +546,8 @@ func main() {
|
||||
cmd.DoGitLabTokenLogin(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if cursorLogin {
|
||||
cmd.DoCursorLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
|
||||
@@ -313,6 +313,10 @@ nonstream-keepalive-interval: 0
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||
# you select the protocol surface, but inference backend selection can still follow the resolved
|
||||
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
@@ -2137,9 +2138,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata["oauth_client_id"] = clientID
|
||||
if clientSecret != "" {
|
||||
metadata["oauth_client_secret"] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
@@ -3707,3 +3705,84 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
"verification_uri": resp.VerificationURL,
|
||||
})
|
||||
}
|
||||
|
||||
// RequestCursorToken initiates the Cursor PKCE authentication flow.
|
||||
// Supports multiple accounts via ?label=xxx query parameter.
|
||||
// The user opens the returned URL in a browser, logs in, and the server polls
|
||||
// until the authentication completes.
|
||||
func (h *Handler) RequestCursorToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
label := strings.TrimSpace(c.Query("label"))
|
||||
log.Infof("Initializing Cursor authentication (label=%q)...", label)
|
||||
|
||||
authParams, err := cursorauth.GenerateAuthParams()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate Cursor auth params: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate auth params"})
|
||||
return
|
||||
}
|
||||
|
||||
state := fmt.Sprintf("cur-%d", time.Now().UnixNano())
|
||||
RegisterOAuthSession(state, "cursor")
|
||||
|
||||
go func() {
|
||||
log.Info("Waiting for Cursor authentication...")
|
||||
log.Infof("Open this URL in your browser: %s", authParams.LoginURL)
|
||||
|
||||
tokens, errPoll := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
|
||||
if errPoll != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed: "+errPoll.Error())
|
||||
log.Errorf("Cursor authentication failed: %v", errPoll)
|
||||
return
|
||||
}
|
||||
|
||||
// Build metadata
|
||||
metadata := map[string]any{
|
||||
"type": "cursor",
|
||||
"access_token": tokens.AccessToken,
|
||||
"refresh_token": tokens.RefreshToken,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// Extract expiry and account identity from JWT
|
||||
expiry := cursorauth.GetTokenExpiry(tokens.AccessToken)
|
||||
if !expiry.IsZero() {
|
||||
metadata["expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Auto-identify account from JWT sub claim for multi-account support
|
||||
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
|
||||
subHash := cursorauth.SubToShortHash(sub)
|
||||
if sub != "" {
|
||||
metadata["sub"] = sub
|
||||
}
|
||||
|
||||
fileName := cursorauth.CredentialFileName(label, subHash)
|
||||
displayLabel := cursorauth.DisplayLabel(label, subHash)
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "cursor",
|
||||
FileName: fileName,
|
||||
Label: displayLabel,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save Cursor tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save tokens")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Cursor authentication successful! Token saved to %s", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("cursor")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"url": authParams.LoginURL,
|
||||
"state": state,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -682,6 +682,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
|
||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
|
||||
33
internal/auth/cursor/filename.go
Normal file
33
internal/auth/cursor/filename.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Cursor credentials.
|
||||
// Priority: explicit label > auto-generated from JWT sub hash.
|
||||
// If both label and subHash are empty, falls back to "cursor.json".
|
||||
func CredentialFileName(label, subHash string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
subHash = strings.TrimSpace(subHash)
|
||||
if label != "" {
|
||||
return fmt.Sprintf("cursor.%s.json", label)
|
||||
}
|
||||
if subHash != "" {
|
||||
return fmt.Sprintf("cursor.%s.json", subHash)
|
||||
}
|
||||
return "cursor.json"
|
||||
}
|
||||
|
||||
// DisplayLabel returns a human-readable label for the Cursor account.
|
||||
func DisplayLabel(label, subHash string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
if label != "" {
|
||||
return "Cursor " + label
|
||||
}
|
||||
if subHash != "" {
|
||||
return "Cursor " + subHash
|
||||
}
|
||||
return "Cursor User"
|
||||
}
|
||||
249
internal/auth/cursor/oauth.go
Normal file
249
internal/auth/cursor/oauth.go
Normal file
@@ -0,0 +1,249 @@
|
||||
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CursorLoginURL = "https://cursor.com/loginDeepControl"
|
||||
CursorPollURL = "https://api2.cursor.sh/auth/poll"
|
||||
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
|
||||
|
||||
pollMaxAttempts = 150
|
||||
pollBaseDelay = 1 * time.Second
|
||||
pollMaxDelay = 10 * time.Second
|
||||
pollBackoffMultiply = 1.2
|
||||
maxConsecutiveErrors = 10
|
||||
)
|
||||
|
||||
// AuthParams holds the PKCE parameters for Cursor login.
|
||||
type AuthParams struct {
|
||||
Verifier string
|
||||
Challenge string
|
||||
UUID string
|
||||
LoginURL string
|
||||
}
|
||||
|
||||
// TokenPair holds the access and refresh tokens from Cursor.
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// GeneratePKCE creates a PKCE verifier and challenge pair.
|
||||
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
verifierBytes := make([]byte, 96)
|
||||
if _, err = rand.Read(verifierBytes); err != nil {
|
||||
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// GenerateAuthParams creates the full set of auth params for Cursor login.
|
||||
func GenerateAuthParams() (*AuthParams, error) {
|
||||
verifier, challenge, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uuidBytes := make([]byte, 16)
|
||||
if _, err = rand.Read(uuidBytes); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
|
||||
}
|
||||
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
|
||||
|
||||
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
|
||||
CursorLoginURL, challenge, uuid)
|
||||
|
||||
return &AuthParams{
|
||||
Verifier: verifier,
|
||||
Challenge: challenge,
|
||||
UUID: uuid,
|
||||
LoginURL: loginURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PollForAuth polls the Cursor auth endpoint until the user completes login.
|
||||
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
|
||||
delay := pollBaseDelay
|
||||
consecutiveErrors := 0
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= maxConsecutiveErrors {
|
||||
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
|
||||
}
|
||||
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// Still waiting for user to authorize
|
||||
consecutiveErrors = 0
|
||||
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
var tokens TokenPair
|
||||
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
|
||||
}
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
|
||||
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
|
||||
}
|
||||
|
||||
// RefreshToken refreshes a Cursor access token using the refresh token.
|
||||
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
|
||||
strings.NewReader("{}"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+refreshToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokens TokenPair
|
||||
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Keep original refresh token if not returned
|
||||
if tokens.RefreshToken == "" {
|
||||
tokens.RefreshToken = refreshToken
|
||||
}
|
||||
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
|
||||
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
|
||||
// the account. Returns empty string if parsing fails.
|
||||
func ParseJWTSub(token string) string {
|
||||
decoded := decodeJWTPayload(token)
|
||||
if decoded == nil {
|
||||
return ""
|
||||
}
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
|
||||
// e.g. "auth0|user_2x..." → "a3f8b2c1"
|
||||
func SubToShortHash(sub string) string {
|
||||
if sub == "" {
|
||||
return ""
|
||||
}
|
||||
h := sha256.Sum256([]byte(sub))
|
||||
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
|
||||
}
|
||||
|
||||
// decodeJWTPayload decodes the payload (middle) part of a JWT.
|
||||
func decodeJWTPayload(token string) []byte {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil
|
||||
}
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
payload = strings.ReplaceAll(payload, "-", "+")
|
||||
payload = strings.ReplaceAll(payload, "_", "/")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
|
||||
// Falls back to 1 hour from now if the token can't be parsed.
|
||||
func GetTokenExpiry(token string) time.Time {
|
||||
decoded := decodeJWTPayload(token)
|
||||
if decoded == nil {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Exp float64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
sec, frac := math.Modf(claims.Exp)
|
||||
expiry := time.Unix(int64(sec), int64(frac*1e9))
|
||||
// Subtract 5-minute safety margin
|
||||
return expiry.Add(-5 * time.Minute)
|
||||
}
|
||||
|
||||
func minDuration(a, b time.Duration) time.Duration {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
84
internal/auth/cursor/proto/connect.go
Normal file
84
internal/auth/cursor/proto/connect.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
|
||||
ConnectEndStreamFlag byte = 0x02
|
||||
// ConnectCompressionFlag indicates the payload is compressed (not supported).
|
||||
ConnectCompressionFlag byte = 0x01
|
||||
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
|
||||
ConnectFrameHeaderSize = 5
|
||||
)
|
||||
|
||||
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
|
||||
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
|
||||
func FrameConnectMessage(data []byte, flags byte) []byte {
|
||||
frame := make([]byte, ConnectFrameHeaderSize+len(data))
|
||||
frame[0] = flags
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
|
||||
copy(frame[5:], data)
|
||||
return frame
|
||||
}
|
||||
|
||||
// ParseConnectFrame extracts one frame from a buffer.
|
||||
// Returns (flags, payload, bytesConsumed, ok).
|
||||
// ok is false when the buffer is too short for a complete frame.
|
||||
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
|
||||
if len(buf) < ConnectFrameHeaderSize {
|
||||
return 0, nil, 0, false
|
||||
}
|
||||
flags = buf[0]
|
||||
length := binary.BigEndian.Uint32(buf[1:5])
|
||||
total := ConnectFrameHeaderSize + int(length)
|
||||
if len(buf) < total {
|
||||
return 0, nil, 0, false
|
||||
}
|
||||
return flags, buf[5:total], total, true
|
||||
}
|
||||
|
||||
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
|
||||
// The Code field contains the server-defined error code (e.g. gRPC standard codes
|
||||
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
|
||||
type ConnectError struct {
|
||||
Code string // server-defined error code
|
||||
Message string // human-readable error description
|
||||
}
|
||||
|
||||
func (e *ConnectError) Error() string {
|
||||
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
|
||||
// Returns nil if there is no error in the trailer.
|
||||
// On error, returns a *ConnectError with the server's error code and message.
|
||||
func ParseConnectEndStream(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var trailer struct {
|
||||
Error *struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &trailer); err != nil {
|
||||
return fmt.Errorf("failed to parse Connect end stream: %w", err)
|
||||
}
|
||||
if trailer.Error != nil {
|
||||
code := trailer.Error.Code
|
||||
if code == "" {
|
||||
code = "unknown"
|
||||
}
|
||||
msg := trailer.Error.Message
|
||||
if msg == "" {
|
||||
msg = "Unknown error"
|
||||
}
|
||||
return &ConnectError{Code: code, Message: msg}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
564
internal/auth/cursor/proto/decode.go
Normal file
564
internal/auth/cursor/proto/decode.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
// ServerMessageType identifies the kind of decoded server message.
|
||||
type ServerMessageType int
|
||||
|
||||
const (
|
||||
ServerMsgUnknown ServerMessageType = iota
|
||||
ServerMsgTextDelta // Text content delta
|
||||
ServerMsgThinkingDelta // Thinking/reasoning delta
|
||||
ServerMsgThinkingCompleted // Thinking completed
|
||||
ServerMsgKvGetBlob // Server wants a blob
|
||||
ServerMsgKvSetBlob // Server wants to store a blob
|
||||
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
|
||||
ServerMsgExecMcpArgs // Server wants MCP tool execution
|
||||
ServerMsgExecShellArgs // Rejected: shell command
|
||||
ServerMsgExecReadArgs // Rejected: file read
|
||||
ServerMsgExecWriteArgs // Rejected: file write
|
||||
ServerMsgExecDeleteArgs // Rejected: file delete
|
||||
ServerMsgExecLsArgs // Rejected: directory listing
|
||||
ServerMsgExecGrepArgs // Rejected: grep search
|
||||
ServerMsgExecFetchArgs // Rejected: HTTP fetch
|
||||
ServerMsgExecDiagnostics // Respond with empty diagnostics
|
||||
ServerMsgExecShellStream // Rejected: shell stream
|
||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||
ServerMsgExecOther // Other exec types (respond with empty)
|
||||
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||
ServerMsgHeartbeat // Server heartbeat
|
||||
ServerMsgTokenDelta // Token usage delta
|
||||
ServerMsgCheckpoint // Conversation checkpoint update
|
||||
)
|
||||
|
||||
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||
type DecodedServerMessage struct {
|
||||
Type ServerMessageType
|
||||
|
||||
// For text/thinking deltas
|
||||
Text string
|
||||
|
||||
// For KV messages
|
||||
KvId uint32
|
||||
BlobId []byte // hex-encoded blob ID
|
||||
BlobData []byte // for setBlobArgs
|
||||
|
||||
// For exec messages
|
||||
ExecMsgId uint32
|
||||
ExecId string
|
||||
|
||||
// For MCP args
|
||||
McpToolName string
|
||||
McpToolCallId string
|
||||
McpArgs map[string][]byte // arg name -> protobuf-encoded value
|
||||
|
||||
// For rejection context
|
||||
Path string
|
||||
Command string
|
||||
WorkingDirectory string
|
||||
Url string
|
||||
|
||||
// For other exec - the raw field number for building a response
|
||||
ExecFieldNumber int
|
||||
|
||||
// For TokenDeltaUpdate
|
||||
TokenDelta int64
|
||||
|
||||
// For conversation checkpoint update (raw bytes, not decoded)
|
||||
CheckpointData []byte
|
||||
}
|
||||
|
||||
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
||||
// a structured representation of the first meaningful message found.
|
||||
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
|
||||
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
|
||||
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid tag")
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid bytes field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
// Debug: log top-level ASM fields
|
||||
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
|
||||
|
||||
switch num {
|
||||
case ASM_InteractionUpdate:
|
||||
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
|
||||
decodeInteractionUpdate(val, msg)
|
||||
case ASM_ExecServerMessage:
|
||||
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
|
||||
decodeExecServerMessage(val, msg)
|
||||
case ASM_KvServerMessage:
|
||||
decodeKvServerMessage(val, msg)
|
||||
case ASM_ConversationCheckpoint:
|
||||
msg.Type = ServerMsgCheckpoint
|
||||
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
|
||||
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
|
||||
}
|
||||
|
||||
case protowire.VarintType:
|
||||
_, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid varint field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
default:
|
||||
// Skip unknown wire types
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
|
||||
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||
|
||||
switch num {
|
||||
case IU_TextDelta:
|
||||
msg.Type = ServerMsgTextDelta
|
||||
msg.Text = decodeStringField(val, TDU_Text)
|
||||
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
|
||||
case IU_ThinkingDelta:
|
||||
msg.Type = ServerMsgThinkingDelta
|
||||
msg.Text = decodeStringField(val, TKD_Text)
|
||||
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
|
||||
case IU_ThinkingCompleted:
|
||||
msg.Type = ServerMsgThinkingCompleted
|
||||
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
|
||||
case 2:
|
||||
// tool_call_started - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
|
||||
case 3:
|
||||
// tool_call_completed - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
||||
case 8:
|
||||
// token_delta - extract token count
|
||||
msg.Type = ServerMsgTokenDelta
|
||||
msg.TokenDelta = decodeVarintField(val, 1)
|
||||
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
|
||||
case 13:
|
||||
// heartbeat from server
|
||||
msg.Type = ServerMsgHeartbeat
|
||||
case 14:
|
||||
// turn_ended - critical: model finished generating
|
||||
msg.Type = ServerMsgTurnEnded
|
||||
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
|
||||
case 16:
|
||||
// step_started - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
|
||||
case 17:
|
||||
// step_completed - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
|
||||
default:
|
||||
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.VarintType:
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == KSM_Id {
|
||||
msg.KvId = uint32(val)
|
||||
}
|
||||
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch num {
|
||||
case KSM_GetBlobArgs:
|
||||
msg.Type = ServerMsgKvGetBlob
|
||||
msg.BlobId = decodeBytesField(val, GBA_BlobId)
|
||||
case KSM_SetBlobArgs:
|
||||
msg.Type = ServerMsgKvSetBlob
|
||||
decodeSetBlobArgs(val, msg)
|
||||
}
|
||||
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
switch num {
|
||||
case SBA_BlobId:
|
||||
msg.BlobId = val
|
||||
case SBA_BlobData:
|
||||
msg.BlobData = val
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.VarintType:
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == ESM_Id {
|
||||
msg.ExecMsgId = uint32(val)
|
||||
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
|
||||
}
|
||||
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
// Debug: log all fields found in ExecServerMessage
|
||||
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||
|
||||
switch num {
|
||||
case ESM_ExecId:
|
||||
msg.ExecId = string(val)
|
||||
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
|
||||
case ESM_RequestContextArgs:
|
||||
msg.Type = ServerMsgExecRequestCtx
|
||||
case ESM_McpArgs:
|
||||
msg.Type = ServerMsgExecMcpArgs
|
||||
decodeMcpArgs(val, msg)
|
||||
case ESM_ShellArgs:
|
||||
msg.Type = ServerMsgExecShellArgs
|
||||
decodeShellArgs(val, msg)
|
||||
case ESM_ShellStreamArgs:
|
||||
msg.Type = ServerMsgExecShellStream
|
||||
decodeShellArgs(val, msg)
|
||||
case ESM_ReadArgs:
|
||||
msg.Type = ServerMsgExecReadArgs
|
||||
msg.Path = decodeStringField(val, RA_Path)
|
||||
case ESM_WriteArgs:
|
||||
msg.Type = ServerMsgExecWriteArgs
|
||||
msg.Path = decodeStringField(val, WA_Path)
|
||||
case ESM_DeleteArgs:
|
||||
msg.Type = ServerMsgExecDeleteArgs
|
||||
msg.Path = decodeStringField(val, DA_Path)
|
||||
case ESM_LsArgs:
|
||||
msg.Type = ServerMsgExecLsArgs
|
||||
msg.Path = decodeStringField(val, LA_Path)
|
||||
case ESM_GrepArgs:
|
||||
msg.Type = ServerMsgExecGrepArgs
|
||||
case ESM_FetchArgs:
|
||||
msg.Type = ServerMsgExecFetchArgs
|
||||
msg.Url = decodeStringField(val, FA_Url)
|
||||
case ESM_DiagnosticsArgs:
|
||||
msg.Type = ServerMsgExecDiagnostics
|
||||
case ESM_BackgroundShellSpawn:
|
||||
msg.Type = ServerMsgExecBgShellSpawn
|
||||
decodeShellArgs(val, msg) // same structure
|
||||
case ESM_WriteShellStdinArgs:
|
||||
msg.Type = ServerMsgExecWriteShellStdin
|
||||
default:
|
||||
// Unknown exec types - only set if we haven't identified the type yet
|
||||
// (other fields like span_context (19) come after the exec type field)
|
||||
if msg.Type == ServerMsgUnknown {
|
||||
msg.Type = ServerMsgExecOther
|
||||
msg.ExecFieldNumber = int(num)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
|
||||
msg.McpArgs = make(map[string][]byte)
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch num {
|
||||
case MCA_Name:
|
||||
msg.McpToolName = string(val)
|
||||
case MCA_Args:
|
||||
// Map entries are encoded as submessages with key=1, value=2
|
||||
decodeMapEntry(val, msg.McpArgs)
|
||||
case MCA_ToolCallId:
|
||||
msg.McpToolCallId = string(val)
|
||||
case MCA_ToolName:
|
||||
// ToolName takes precedence if present
|
||||
if msg.McpToolName == "" || string(val) != "" {
|
||||
msg.McpToolName = string(val)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeMapEntry(data []byte, m map[string][]byte) {
|
||||
var key string
|
||||
var value []byte
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == 1 {
|
||||
key = string(val)
|
||||
} else if num == 2 {
|
||||
value = append([]byte(nil), val...)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
if key != "" {
|
||||
m[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
switch num {
|
||||
case SHA_Command:
|
||||
msg.Command = string(val)
|
||||
case SHA_WorkingDirectory:
|
||||
msg.WorkingDirectory = string(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper decoders ---
|
||||
|
||||
// decodeStringField extracts a string from the first matching field in a submessage.
|
||||
func decodeStringField(data []byte, targetField protowire.Number) string {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return string(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// decodeBytesField extracts bytes from the first matching field in a submessage.
|
||||
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return append([]byte(nil), val...)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if typ == protowire.VarintType {
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return int64(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
||||
func BlobIdHex(blobId []byte) string {
|
||||
return hex.EncodeToString(blobId)
|
||||
}
|
||||
|
||||
1244
internal/auth/cursor/proto/descriptor.go
Normal file
1244
internal/auth/cursor/proto/descriptor.go
Normal file
File diff suppressed because it is too large
Load Diff
664
internal/auth/cursor/proto/encode.go
Normal file
664
internal/auth/cursor/proto/encode.go
Normal file
@@ -0,0 +1,664 @@
|
||||
// Package proto provides protobuf encoding for Cursor's gRPC API,
|
||||
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
|
||||
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
|
||||
package proto
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/dynamicpb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// --- Public types ---
|
||||
|
||||
// RunRequestParams holds all data needed to build an AgentRunRequest.
|
||||
type RunRequestParams struct {
|
||||
ModelId string
|
||||
SystemPrompt string
|
||||
UserText string
|
||||
MessageId string
|
||||
ConversationId string
|
||||
Images []ImageData
|
||||
Turns []TurnData
|
||||
McpTools []McpToolDef
|
||||
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
|
||||
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
MimeType string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type TurnData struct {
|
||||
UserText string
|
||||
AssistantText string
|
||||
}
|
||||
|
||||
type McpToolDef struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema json.RawMessage
|
||||
}
|
||||
|
||||
// --- Helper: create a dynamic message and set fields ---
|
||||
|
||||
func newMsg(name string) *dynamicpb.Message {
|
||||
return dynamicpb.NewMessage(Msg(name))
|
||||
}
|
||||
|
||||
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
|
||||
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
|
||||
}
|
||||
|
||||
func setStr(msg *dynamicpb.Message, name, val string) {
|
||||
if val != "" {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
|
||||
}
|
||||
}
|
||||
|
||||
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
|
||||
if len(val) > 0 {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
|
||||
}
|
||||
}
|
||||
|
||||
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
|
||||
}
|
||||
|
||||
func setBool(msg *dynamicpb.Message, name string, val bool) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
|
||||
}
|
||||
|
||||
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
|
||||
}
|
||||
|
||||
func marshal(msg *dynamicpb.Message) []byte {
|
||||
b, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
panic("cursor proto marshal: " + err.Error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// --- Encode functions mirroring cursor-fetch.ts ---
|
||||
|
||||
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
|
||||
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
|
||||
func EncodeHeartbeat() []byte {
|
||||
hb := newMsg("ClientHeartbeat")
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "client_heartbeat", hb)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
|
||||
// Mirrors buildCursorRequest() in cursor-fetch.ts.
|
||||
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
|
||||
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
|
||||
func EncodeRunRequest(p *RunRequestParams) []byte {
|
||||
if p.RawCheckpoint != nil {
|
||||
return encodeRunRequestWithCheckpoint(p)
|
||||
}
|
||||
|
||||
if p.BlobStore == nil {
|
||||
p.BlobStore = make(map[string][]byte)
|
||||
}
|
||||
|
||||
// --- Conversation turns ---
|
||||
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
|
||||
var turnBytes [][]byte
|
||||
for _, turn := range p.Turns {
|
||||
// UserMessage for this turn
|
||||
um := newMsg("UserMessage")
|
||||
setStr(um, "text", turn.UserText)
|
||||
setStr(um, "message_id", generateId())
|
||||
umBytes := marshal(um)
|
||||
|
||||
// Steps (assistant response)
|
||||
var stepBytes [][]byte
|
||||
if turn.AssistantText != "" {
|
||||
am := newMsg("AssistantMessage")
|
||||
setStr(am, "text", turn.AssistantText)
|
||||
step := newMsg("ConversationStep")
|
||||
setMsg(step, "assistant_message", am)
|
||||
stepBytes = append(stepBytes, marshal(step))
|
||||
}
|
||||
|
||||
// AgentConversationTurnStructure (fields are bytes, not submessages)
|
||||
agentTurn := newMsg("AgentConversationTurnStructure")
|
||||
setBytes(agentTurn, "user_message", umBytes)
|
||||
for _, sb := range stepBytes {
|
||||
stepsField := field(agentTurn, "steps")
|
||||
list := agentTurn.Mutable(stepsField).List()
|
||||
list.Append(protoreflect.ValueOfBytes(sb))
|
||||
}
|
||||
|
||||
// ConversationTurnStructure (oneof turn → agentConversationTurn)
|
||||
cts := newMsg("ConversationTurnStructure")
|
||||
setMsg(cts, "agent_conversation_turn", agentTurn)
|
||||
turnBytes = append(turnBytes, marshal(cts))
|
||||
}
|
||||
|
||||
// --- System prompt blob ---
|
||||
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
|
||||
blobId := sha256Sum(systemJSON)
|
||||
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
|
||||
|
||||
// --- ConversationStateStructure ---
|
||||
css := newMsg("ConversationStateStructure")
|
||||
// rootPromptMessagesJson: repeated bytes
|
||||
rootField := field(css, "root_prompt_messages_json")
|
||||
rootList := css.Mutable(rootField).List()
|
||||
rootList.Append(protoreflect.ValueOfBytes(blobId))
|
||||
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
|
||||
turnsField := field(css, "turns")
|
||||
turnsList := css.Mutable(turnsField).List()
|
||||
for _, tb := range turnBytes {
|
||||
turnsList.Append(protoreflect.ValueOfBytes(tb))
|
||||
}
|
||||
turnsOldField := field(css, "turns_old")
|
||||
if turnsOldField != nil {
|
||||
turnsOldList := css.Mutable(turnsOldField).List()
|
||||
for _, tb := range turnBytes {
|
||||
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
|
||||
}
|
||||
}
|
||||
|
||||
// --- UserMessage (current) ---
|
||||
userMessage := newMsg("UserMessage")
|
||||
setStr(userMessage, "text", p.UserText)
|
||||
setStr(userMessage, "message_id", p.MessageId)
|
||||
|
||||
// Images via SelectedContext
|
||||
if len(p.Images) > 0 {
|
||||
sc := newMsg("SelectedContext")
|
||||
imgsField := field(sc, "selected_images")
|
||||
imgsList := sc.Mutable(imgsField).List()
|
||||
for _, img := range p.Images {
|
||||
si := newMsg("SelectedImage")
|
||||
setStr(si, "uuid", generateId())
|
||||
setStr(si, "mime_type", img.MimeType)
|
||||
setBytes(si, "data", img.Data)
|
||||
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||
}
|
||||
setMsg(userMessage, "selected_context", sc)
|
||||
}
|
||||
|
||||
// --- UserMessageAction ---
|
||||
uma := newMsg("UserMessageAction")
|
||||
setMsg(uma, "user_message", userMessage)
|
||||
|
||||
// --- ConversationAction ---
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "user_message_action", uma)
|
||||
|
||||
// --- ModelDetails ---
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
|
||||
// --- AgentRunRequest ---
|
||||
arr := newMsg("AgentRunRequest")
|
||||
setMsg(arr, "conversation_state", css)
|
||||
setMsg(arr, "action", ca)
|
||||
setMsg(arr, "model_details", md)
|
||||
setStr(arr, "conversation_id", p.ConversationId)
|
||||
|
||||
// McpTools
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
setMsg(arr, "mcp_tools", mcpTools)
|
||||
}
|
||||
|
||||
// --- AgentClientMessage ---
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "run_request", arr)
|
||||
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
|
||||
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
|
||||
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
|
||||
// Build UserMessage
|
||||
userMessage := newMsg("UserMessage")
|
||||
setStr(userMessage, "text", p.UserText)
|
||||
setStr(userMessage, "message_id", p.MessageId)
|
||||
if len(p.Images) > 0 {
|
||||
sc := newMsg("SelectedContext")
|
||||
imgsField := field(sc, "selected_images")
|
||||
imgsList := sc.Mutable(imgsField).List()
|
||||
for _, img := range p.Images {
|
||||
si := newMsg("SelectedImage")
|
||||
setStr(si, "uuid", generateId())
|
||||
setStr(si, "mime_type", img.MimeType)
|
||||
setBytes(si, "data", img.Data)
|
||||
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||
}
|
||||
setMsg(userMessage, "selected_context", sc)
|
||||
}
|
||||
|
||||
// Build ConversationAction with UserMessageAction
|
||||
uma := newMsg("UserMessageAction")
|
||||
setMsg(uma, "user_message", userMessage)
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "user_message_action", uma)
|
||||
caBytes := marshal(ca)
|
||||
|
||||
// Build ModelDetails
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
mdBytes := marshal(md)
|
||||
|
||||
// Build McpTools
|
||||
var mcpToolsBytes []byte
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
mcpToolsBytes = marshal(mcpTools)
|
||||
}
|
||||
|
||||
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
|
||||
var arrBuf []byte
|
||||
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
|
||||
// field 2: action = ConversationAction
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
|
||||
// field 3: model_details = ModelDetails
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
|
||||
// field 4: mcp_tools = McpTools
|
||||
if len(mcpToolsBytes) > 0 {
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
|
||||
}
|
||||
// field 5: conversation_id = string
|
||||
if p.ConversationId != "" {
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
|
||||
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
|
||||
}
|
||||
|
||||
// Wrap in AgentClientMessage field 1 (run_request)
|
||||
var acmBuf []byte
|
||||
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
|
||||
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
|
||||
|
||||
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
|
||||
return acmBuf
|
||||
}
|
||||
|
||||
// ResumeRequestParams holds data for a ResumeAction request.
|
||||
type ResumeRequestParams struct {
|
||||
ModelId string
|
||||
ConversationId string
|
||||
McpTools []McpToolDef
|
||||
}
|
||||
|
||||
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
|
||||
// Used to resume a conversation by conversation_id without re-sending full history.
|
||||
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
|
||||
// RequestContext with tools
|
||||
rc := newMsg("RequestContext")
|
||||
if len(p.McpTools) > 0 {
|
||||
toolsField := field(rc, "tools")
|
||||
toolsList := rc.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
}
|
||||
|
||||
// ResumeAction
|
||||
ra := newMsg("ResumeAction")
|
||||
setMsg(ra, "request_context", rc)
|
||||
|
||||
// ConversationAction with resume_action
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "resume_action", ra)
|
||||
|
||||
// ModelDetails
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
|
||||
// AgentRunRequest — no conversation_state needed for resume
|
||||
arr := newMsg("AgentRunRequest")
|
||||
setMsg(arr, "action", ca)
|
||||
setMsg(arr, "model_details", md)
|
||||
setStr(arr, "conversation_id", p.ConversationId)
|
||||
|
||||
// McpTools at top level
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
setMsg(arr, "mcp_tools", mcpTools)
|
||||
}
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "run_request", arr)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// --- KV response encoders ---
|
||||
// Mirrors handleKvMessage() in cursor-fetch.ts
|
||||
|
||||
// EncodeKvGetBlobResult responds to a getBlobArgs request.
|
||||
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
|
||||
result := newMsg("GetBlobResult")
|
||||
if blobData != nil {
|
||||
setBytes(result, "blob_data", blobData)
|
||||
}
|
||||
|
||||
kvc := newMsg("KvClientMessage")
|
||||
setUint32(kvc, "id", kvId)
|
||||
setMsg(kvc, "get_blob_result", result)
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "kv_client_message", kvc)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// EncodeKvSetBlobResult responds to a setBlobArgs request.
|
||||
func EncodeKvSetBlobResult(kvId uint32) []byte {
|
||||
result := newMsg("SetBlobResult")
|
||||
|
||||
kvc := newMsg("KvClientMessage")
|
||||
setUint32(kvc, "id", kvId)
|
||||
setMsg(kvc, "set_blob_result", result)
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "kv_client_message", kvc)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// --- Exec response encoders ---
|
||||
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
|
||||
|
||||
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
|
||||
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
|
||||
// RequestContext with tools
|
||||
rc := newMsg("RequestContext")
|
||||
if len(tools) > 0 {
|
||||
toolsField := field(rc, "tools")
|
||||
toolsList := rc.Mutable(toolsField).List()
|
||||
for _, tool := range tools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
}
|
||||
|
||||
// RequestContextSuccess
|
||||
rcs := newMsg("RequestContextSuccess")
|
||||
setMsg(rcs, "request_context", rc)
|
||||
|
||||
// RequestContextResult (oneof success)
|
||||
rcr := newMsg("RequestContextResult")
|
||||
setMsg(rcr, "success", rcs)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
|
||||
}
|
||||
|
||||
// EncodeExecMcpResult responds with MCP tool result.
|
||||
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
|
||||
textContent := newMsg("McpTextContent")
|
||||
setStr(textContent, "text", content)
|
||||
|
||||
contentItem := newMsg("McpToolResultContentItem")
|
||||
setMsg(contentItem, "text", textContent)
|
||||
|
||||
success := newMsg("McpSuccess")
|
||||
contentField := field(success, "content")
|
||||
contentList := success.Mutable(contentField).List()
|
||||
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
|
||||
setBool(success, "is_error", isError)
|
||||
|
||||
result := newMsg("McpResult")
|
||||
setMsg(result, "success", success)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||
}
|
||||
|
||||
// EncodeExecMcpError responds with MCP error.
|
||||
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
mcpErr := newMsg("McpError")
|
||||
setStr(mcpErr, "error", errMsg)
|
||||
|
||||
result := newMsg("McpResult")
|
||||
setMsg(result, "error", mcpErr)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||
}
|
||||
|
||||
// --- Rejection encoders (mirror handleExecMessage rejections) ---
|
||||
|
||||
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("ReadRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("ReadResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||
rej := newMsg("ShellRejected")
|
||||
setStr(rej, "command", command)
|
||||
setStr(rej, "working_directory", workDir)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("ShellResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("WriteRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("WriteResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("DeleteRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("DeleteResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("LsRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("LsResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
grepErr := newMsg("GrepError")
|
||||
setStr(grepErr, "error", errMsg)
|
||||
result := newMsg("GrepResult")
|
||||
setMsg(result, "error", grepErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
|
||||
fetchErr := newMsg("FetchError")
|
||||
setStr(fetchErr, "url", url)
|
||||
setStr(fetchErr, "error", errMsg)
|
||||
result := newMsg("FetchResult")
|
||||
setMsg(result, "error", fetchErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
|
||||
result := newMsg("DiagnosticsResult")
|
||||
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||
rej := newMsg("ShellRejected")
|
||||
setStr(rej, "command", command)
|
||||
setStr(rej, "working_directory", workDir)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("BackgroundShellSpawnResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
wsErr := newMsg("WriteShellStdinError")
|
||||
setStr(wsErr, "error", errMsg)
|
||||
result := newMsg("WriteShellStdinResult")
|
||||
setMsg(result, "error", wsErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
|
||||
}
|
||||
|
||||
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
|
||||
// Mirrors sendExec() in cursor-fetch.ts.
|
||||
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
|
||||
ecm := newMsg("ExecClientMessage")
|
||||
setUint32(ecm, "id", id)
|
||||
// Force set exec_id even if empty - Cursor requires this field to be set
|
||||
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
|
||||
|
||||
// Debug: check if field exists
|
||||
fd := field(ecm, resultFieldName)
|
||||
if fd == nil {
|
||||
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
|
||||
}
|
||||
|
||||
// Debug: log the actual field being set
|
||||
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
|
||||
|
||||
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "exec_client_message", ecm)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
func listFields(msg *dynamicpb.Message) []string {
|
||||
var names []string
|
||||
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
|
||||
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --- Utilities ---
|
||||
|
||||
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
|
||||
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
|
||||
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
|
||||
if len(jsonData) == 0 {
|
||||
return nil
|
||||
}
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(jsonData, &v); err != nil {
|
||||
return jsonData // fallback to raw JSON if parsing fails
|
||||
}
|
||||
pbVal, err := structpb.NewValue(v)
|
||||
if err != nil {
|
||||
return jsonData // fallback
|
||||
}
|
||||
b, err := proto.Marshal(pbVal)
|
||||
if err != nil {
|
||||
return jsonData // fallback
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
|
||||
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
|
||||
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
|
||||
val := &structpb.Value{}
|
||||
if err := proto.Unmarshal(data, val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return val.AsInterface(), nil
|
||||
}
|
||||
|
||||
func sha256Sum(data []byte) []byte {
|
||||
h := sha256.Sum256(data)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
var idCounter uint64
|
||||
|
||||
func generateId() string {
|
||||
idCounter++
|
||||
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
@@ -0,0 +1,332 @@
|
||||
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
|
||||
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
|
||||
package proto
|
||||
|
||||
// AgentClientMessage (msg 118) oneof "message"
|
||||
const (
|
||||
ACM_RunRequest = 1 // AgentRunRequest
|
||||
ACM_ExecClientMessage = 2 // ExecClientMessage
|
||||
ACM_KvClientMessage = 3 // KvClientMessage
|
||||
ACM_ConversationAction = 4 // ConversationAction
|
||||
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
|
||||
ACM_InteractionResponse = 6 // InteractionResponse
|
||||
ACM_ClientHeartbeat = 7 // ClientHeartbeat
|
||||
)
|
||||
|
||||
// AgentServerMessage (msg 119) oneof "message"
|
||||
const (
|
||||
ASM_InteractionUpdate = 1 // InteractionUpdate
|
||||
ASM_ExecServerMessage = 2 // ExecServerMessage
|
||||
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
|
||||
ASM_KvServerMessage = 4 // KvServerMessage
|
||||
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
|
||||
ASM_InteractionQuery = 7 // InteractionQuery
|
||||
)
|
||||
|
||||
// AgentRunRequest (msg 91)
|
||||
const (
|
||||
ARR_ConversationState = 1 // ConversationStateStructure
|
||||
ARR_Action = 2 // ConversationAction
|
||||
ARR_ModelDetails = 3 // ModelDetails
|
||||
ARR_McpTools = 4 // McpTools
|
||||
ARR_ConversationId = 5 // string (optional)
|
||||
)
|
||||
|
||||
// ConversationStateStructure (msg 83)
|
||||
const (
|
||||
CSS_RootPromptMessagesJson = 1 // repeated bytes
|
||||
CSS_TurnsOld = 2 // repeated bytes (deprecated)
|
||||
CSS_Todos = 3 // repeated bytes
|
||||
CSS_PendingToolCalls = 4 // repeated string
|
||||
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
|
||||
CSS_PreviousWorkspaceUris = 9 // repeated string
|
||||
CSS_SelfSummaryCount = 17 // uint32
|
||||
CSS_ReadPaths = 18 // repeated string
|
||||
)
|
||||
|
||||
// ConversationAction (msg 54) oneof "action"
|
||||
const (
|
||||
CA_UserMessageAction = 1 // UserMessageAction
|
||||
)
|
||||
|
||||
// UserMessageAction (msg 55)
|
||||
const (
|
||||
UMA_UserMessage = 1 // UserMessage
|
||||
)
|
||||
|
||||
// UserMessage (msg 63)
|
||||
const (
|
||||
UM_Text = 1 // string
|
||||
UM_MessageId = 2 // string
|
||||
UM_SelectedContext = 3 // SelectedContext (optional)
|
||||
)
|
||||
|
||||
// SelectedContext
|
||||
const (
|
||||
SC_SelectedImages = 1 // repeated SelectedImage
|
||||
)
|
||||
|
||||
// SelectedImage
|
||||
const (
|
||||
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
|
||||
SI_Uuid = 2 // string
|
||||
SI_Path = 3 // string
|
||||
SI_MimeType = 7 // string
|
||||
SI_Data = 8 // bytes (oneof dataOrBlobId)
|
||||
)
|
||||
|
||||
// ModelDetails (msg 88)
|
||||
const (
|
||||
MD_ModelId = 1 // string
|
||||
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
|
||||
MD_DisplayModelId = 3 // string
|
||||
MD_DisplayName = 4 // string
|
||||
)
|
||||
|
||||
// McpTools (msg 307)
|
||||
const (
|
||||
MT_McpTools = 1 // repeated McpToolDefinition
|
||||
)
|
||||
|
||||
// McpToolDefinition (msg 306)
|
||||
const (
|
||||
MTD_Name = 1 // string
|
||||
MTD_Description = 2 // string
|
||||
MTD_InputSchema = 3 // bytes
|
||||
MTD_ProviderIdentifier = 4 // string
|
||||
MTD_ToolName = 5 // string
|
||||
)
|
||||
|
||||
// ConversationTurnStructure (msg 70) oneof "turn"
|
||||
const (
|
||||
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
|
||||
)
|
||||
|
||||
// AgentConversationTurnStructure (msg 72)
|
||||
const (
|
||||
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
|
||||
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
|
||||
)
|
||||
|
||||
// ConversationStep (msg 53) oneof "message"
|
||||
const (
|
||||
CS_AssistantMessage = 1 // AssistantMessage
|
||||
)
|
||||
|
||||
// AssistantMessage
|
||||
const (
|
||||
AM_Text = 1 // string
|
||||
)
|
||||
|
||||
// --- Server-side message fields ---
|
||||
|
||||
// InteractionUpdate oneof "message"
|
||||
const (
|
||||
IU_TextDelta = 1 // TextDeltaUpdate
|
||||
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||
)
|
||||
|
||||
// TextDeltaUpdate (msg 92)
|
||||
const (
|
||||
TDU_Text = 1 // string
|
||||
)
|
||||
|
||||
// ThinkingDeltaUpdate (msg 97)
|
||||
const (
|
||||
TKD_Text = 1 // string
|
||||
)
|
||||
|
||||
// KvServerMessage (msg 271)
|
||||
const (
|
||||
KSM_Id = 1 // uint32
|
||||
KSM_GetBlobArgs = 2 // GetBlobArgs
|
||||
KSM_SetBlobArgs = 3 // SetBlobArgs
|
||||
)
|
||||
|
||||
// GetBlobArgs (msg 267)
|
||||
const (
|
||||
GBA_BlobId = 1 // bytes
|
||||
)
|
||||
|
||||
// SetBlobArgs (msg 269)
|
||||
const (
|
||||
SBA_BlobId = 1 // bytes
|
||||
SBA_BlobData = 2 // bytes
|
||||
)
|
||||
|
||||
// KvClientMessage (msg 272)
|
||||
const (
|
||||
KCM_Id = 1 // uint32
|
||||
KCM_GetBlobResult = 2 // GetBlobResult
|
||||
KCM_SetBlobResult = 3 // SetBlobResult
|
||||
)
|
||||
|
||||
// GetBlobResult (msg 268)
|
||||
const (
|
||||
GBR_BlobData = 1 // bytes (optional)
|
||||
)
|
||||
|
||||
// ExecServerMessage
|
||||
const (
|
||||
ESM_Id = 1 // uint32
|
||||
ESM_ExecId = 15 // string
|
||||
// oneof message:
|
||||
ESM_ShellArgs = 2 // ShellArgs
|
||||
ESM_WriteArgs = 3 // WriteArgs
|
||||
ESM_DeleteArgs = 4 // DeleteArgs
|
||||
ESM_GrepArgs = 5 // GrepArgs
|
||||
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||
ESM_LsArgs = 8 // LsArgs
|
||||
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||
ESM_McpArgs = 11 // McpArgs
|
||||
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||
ESM_FetchArgs = 20 // FetchArgs
|
||||
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||
)
|
||||
|
||||
// ExecClientMessage
|
||||
const (
|
||||
ECM_Id = 1 // uint32
|
||||
ECM_ExecId = 15 // string
|
||||
// oneof message (mirrors server fields):
|
||||
ECM_ShellResult = 2
|
||||
ECM_WriteResult = 3
|
||||
ECM_DeleteResult = 4
|
||||
ECM_GrepResult = 5
|
||||
ECM_ReadResult = 7
|
||||
ECM_LsResult = 8
|
||||
ECM_DiagnosticsResult = 9
|
||||
ECM_RequestContextResult = 10
|
||||
ECM_McpResult = 11
|
||||
ECM_ShellStream = 14
|
||||
ECM_BackgroundShellSpawnRes = 16
|
||||
ECM_FetchResult = 20
|
||||
ECM_WriteShellStdinResult = 23
|
||||
)
|
||||
|
||||
// McpArgs
|
||||
const (
|
||||
MCA_Name = 1 // string
|
||||
MCA_Args = 2 // map<string, bytes>
|
||||
MCA_ToolCallId = 3 // string
|
||||
MCA_ProviderIdentifier = 4 // string
|
||||
MCA_ToolName = 5 // string
|
||||
)
|
||||
|
||||
// RequestContextResult oneof "result"
|
||||
const (
|
||||
RCR_Success = 1 // RequestContextSuccess
|
||||
RCR_Error = 2 // RequestContextError
|
||||
)
|
||||
|
||||
// RequestContextSuccess (msg 337)
|
||||
const (
|
||||
RCS_RequestContext = 1 // RequestContext
|
||||
)
|
||||
|
||||
// RequestContext
|
||||
const (
|
||||
RC_Rules = 2 // repeated CursorRule
|
||||
RC_Tools = 7 // repeated McpToolDefinition
|
||||
)
|
||||
|
||||
// McpResult oneof "result"
|
||||
const (
|
||||
MCR_Success = 1 // McpSuccess
|
||||
MCR_Error = 2 // McpError
|
||||
MCR_Rejected = 3 // McpRejected
|
||||
)
|
||||
|
||||
// McpSuccess (msg 290)
|
||||
const (
|
||||
MCS_Content = 1 // repeated McpToolResultContentItem
|
||||
MCS_IsError = 2 // bool
|
||||
)
|
||||
|
||||
// McpToolResultContentItem oneof "content"
|
||||
const (
|
||||
MTRCI_Text = 1 // McpTextContent
|
||||
)
|
||||
|
||||
// McpTextContent (msg 287)
|
||||
const (
|
||||
MTC_Text = 1 // string
|
||||
)
|
||||
|
||||
// McpError (msg 291)
|
||||
const (
|
||||
MCE_Error = 1 // string
|
||||
)
|
||||
|
||||
// --- Rejection messages ---
|
||||
|
||||
// ReadRejected: path=1, reason=2
|
||||
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
|
||||
// WriteRejected: path=1, reason=2
|
||||
// DeleteRejected: path=1, reason=2
|
||||
// LsRejected: path=1, reason=2
|
||||
// GrepError: error=1
|
||||
// FetchError: url=1, error=2
|
||||
// WriteShellStdinError: error=1
|
||||
|
||||
// ReadResult oneof: success=1, error=2, rejected=3
|
||||
// ShellResult oneof: success=1 (+ various), rejected=?
|
||||
// The TS code uses specific result field numbers from the oneof:
|
||||
const (
|
||||
RR_Rejected = 3 // ReadResult.rejected
|
||||
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||
WR_Rejected = 5 // WriteResult.rejected
|
||||
DR_Rejected = 3 // DeleteResult.rejected
|
||||
LR_Rejected = 3 // LsResult.rejected
|
||||
GR_Error = 2 // GrepResult.error
|
||||
FR_Error = 2 // FetchResult.error
|
||||
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
|
||||
WSSR_Error = 2 // WriteShellStdinResult.error
|
||||
)
|
||||
|
||||
// --- Rejection struct fields ---
|
||||
const (
|
||||
REJ_Path = 1
|
||||
REJ_Reason = 2
|
||||
SREJ_Command = 1
|
||||
SREJ_WorkingDir = 2
|
||||
SREJ_Reason = 3
|
||||
SREJ_IsReadonly = 4
|
||||
GERR_Error = 1
|
||||
FERR_Url = 1
|
||||
FERR_Error = 2
|
||||
)
|
||||
|
||||
// ReadArgs
|
||||
const (
|
||||
RA_Path = 1 // string
|
||||
)
|
||||
|
||||
// WriteArgs
|
||||
const (
|
||||
WA_Path = 1 // string
|
||||
)
|
||||
|
||||
// DeleteArgs
|
||||
const (
|
||||
DA_Path = 1 // string
|
||||
)
|
||||
|
||||
// LsArgs
|
||||
const (
|
||||
LA_Path = 1 // string
|
||||
)
|
||||
|
||||
// ShellArgs
|
||||
const (
|
||||
SHA_Command = 1 // string
|
||||
SHA_WorkingDirectory = 2 // string
|
||||
)
|
||||
|
||||
// FetchArgs
|
||||
const (
|
||||
FA_Url = 1 // string
|
||||
)
|
||||
313
internal/auth/cursor/proto/h2stream.go
Normal file
313
internal/auth/cursor/proto/h2stream.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||
)
|
||||
|
||||
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||
type H2Stream struct {
|
||||
framer *http2.Framer
|
||||
conn net.Conn
|
||||
streamID uint32
|
||||
mu sync.Mutex
|
||||
id string // unique identifier for debugging
|
||||
frameNum int64 // sequential frame counter for debugging
|
||||
|
||||
dataCh chan []byte
|
||||
doneCh chan struct{}
|
||||
err error
|
||||
|
||||
// Send-side flow control
|
||||
sendWindow int32 // available bytes we can send on this stream
|
||||
connWindow int32 // available bytes on the connection level
|
||||
windowCond *sync.Cond // signaled when window is updated
|
||||
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||
}
|
||||
|
||||
// ID returns the unique identifier for this stream (for logging).
|
||||
func (s *H2Stream) ID() string { return s.id }
|
||||
|
||||
// FrameNum returns the current frame number for debugging.
|
||||
func (s *H2Stream) FrameNum() int64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.frameNum
|
||||
}
|
||||
|
||||
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
|
||||
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
|
||||
NextProtos: []string{"h2"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
|
||||
}
|
||||
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: server did not negotiate h2")
|
||||
}
|
||||
|
||||
framer := http2.NewFramer(tlsConn, tlsConn)
|
||||
|
||||
// Client connection preface
|
||||
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||
}
|
||||
|
||||
// Send initial SETTINGS (tell server how much WE can receive)
|
||||
if err := framer.WriteSettings(
|
||||
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||
); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||
}
|
||||
|
||||
// Connection-level window update (for receiving)
|
||||
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||
}
|
||||
|
||||
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||
// Track server's initial window size (how much WE can send)
|
||||
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||
for i := 0; i < 10; i++ {
|
||||
f, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
|
||||
}
|
||||
switch sf := f.(type) {
|
||||
case *http2.SettingsFrame:
|
||||
if !sf.IsAck() {
|
||||
sf.ForeachSetting(func(s http2.Setting) error {
|
||||
if s.ID == http2.SettingInitialWindowSize {
|
||||
serverInitialWindowSize = int32(s.Val)
|
||||
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
framer.WriteSettingsAck()
|
||||
} else {
|
||||
goto handshakeDone
|
||||
}
|
||||
case *http2.WindowUpdateFrame:
|
||||
if sf.StreamID == 0 {
|
||||
connWindowSize += int32(sf.Increment)
|
||||
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||
}
|
||||
default:
|
||||
// unexpected but continue
|
||||
}
|
||||
}
|
||||
handshakeDone:
|
||||
|
||||
// Build HEADERS
|
||||
streamID := uint32(1)
|
||||
var hdrBuf []byte
|
||||
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
|
||||
if p, ok := headers[":path"]; ok {
|
||||
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
|
||||
}
|
||||
for k, v := range headers {
|
||||
if len(k) > 0 && k[0] == ':' {
|
||||
continue
|
||||
}
|
||||
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
|
||||
if err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
BlockFragment: hdrBuf,
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
}); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: headers write failed: %w", err)
|
||||
}
|
||||
|
||||
s := &H2Stream{
|
||||
framer: framer,
|
||||
conn: tlsConn,
|
||||
streamID: streamID,
|
||||
dataCh: make(chan []byte, 256),
|
||||
doneCh: make(chan struct{}),
|
||||
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||
frameNum: 0,
|
||||
sendWindow: serverInitialWindowSize,
|
||||
connWindow: connWindowSize,
|
||||
}
|
||||
s.windowCond = sync.NewCond(&s.windowMu)
|
||||
go s.readLoop()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Write sends a DATA frame on the stream, respecting flow control.
|
||||
func (s *H2Stream) Write(data []byte) error {
|
||||
for len(data) > 0 {
|
||||
chunk := data
|
||||
if len(chunk) > maxFramePayload {
|
||||
chunk = data[:maxFramePayload]
|
||||
}
|
||||
|
||||
// Wait for flow control window
|
||||
s.windowMu.Lock()
|
||||
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||
s.windowCond.Wait()
|
||||
}
|
||||
// Limit chunk to available window
|
||||
allowed := int(s.sendWindow)
|
||||
if int(s.connWindow) < allowed {
|
||||
allowed = int(s.connWindow)
|
||||
}
|
||||
if len(chunk) > allowed {
|
||||
chunk = chunk[:allowed]
|
||||
}
|
||||
s.sendWindow -= int32(len(chunk))
|
||||
s.connWindow -= int32(len(chunk))
|
||||
s.windowMu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[len(chunk):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns the channel of received data chunks.
|
||||
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
|
||||
|
||||
// Done returns a channel closed when the stream ends.
|
||||
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
||||
|
||||
// Err returns the error (if any) that caused the stream to close.
|
||||
// Returns nil for a clean shutdown (EOF / StreamEnded).
|
||||
func (s *H2Stream) Err() error { return s.err }
|
||||
|
||||
// Close tears down the connection.
|
||||
func (s *H2Stream) Close() {
|
||||
s.conn.Close()
|
||||
// Unblock any writers waiting on flow control
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
|
||||
func (s *H2Stream) readLoop() {
|
||||
defer close(s.doneCh)
|
||||
defer close(s.dataCh)
|
||||
|
||||
for {
|
||||
f, err := s.framer.ReadFrame()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
s.err = err
|
||||
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Increment frame counter
|
||||
s.mu.Lock()
|
||||
s.frameNum++
|
||||
s.mu.Unlock()
|
||||
|
||||
switch frame := f.(type) {
|
||||
case *http2.DataFrame:
|
||||
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||
cp := make([]byte, len(frame.Data()))
|
||||
copy(cp, frame.Data())
|
||||
s.dataCh <- cp
|
||||
|
||||
// Flow control: send WINDOW_UPDATE for received data
|
||||
s.mu.Lock()
|
||||
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||
s.mu.Unlock()
|
||||
}
|
||||
if frame.StreamEnded() {
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.HeadersFrame:
|
||||
if frame.StreamEnded() {
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.RSTStreamFrame:
|
||||
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.GoAwayFrame:
|
||||
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.PingFrame:
|
||||
if !frame.IsAck() {
|
||||
s.mu.Lock()
|
||||
s.framer.WritePing(true, frame.Data)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.SettingsFrame:
|
||||
if !frame.IsAck() {
|
||||
// Check for window size changes
|
||||
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||
if setting.ID == http2.SettingInitialWindowSize {
|
||||
s.windowMu.Lock()
|
||||
delta := int32(setting.Val) - s.sendWindow
|
||||
s.sendWindow += delta
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.mu.Lock()
|
||||
s.framer.WriteSettingsAck()
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.WindowUpdateFrame:
|
||||
// Update send-side flow control window
|
||||
s.windowMu.Lock()
|
||||
if frame.StreamID == 0 {
|
||||
s.connWindow += int32(frame.Increment)
|
||||
} else if frame.StreamID == s.streamID {
|
||||
s.sendWindow += int32(frame.Increment)
|
||||
}
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sliceWriter struct{ buf *[]byte }
|
||||
|
||||
func (w *sliceWriter) Write(p []byte) (int, error) {
|
||||
*w.buf = append(*w.buf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -24,6 +24,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||
sdkAuth.NewCursorAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
37
internal/cmd/cursor_login.go
Normal file
37
internal/cmd/cursor_login.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
|
||||
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Cursor authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
log.Infof("Authentication saved to %s", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
log.Infof("Authenticated as %s", record.Label)
|
||||
}
|
||||
log.Info("Cursor authentication successful!")
|
||||
}
|
||||
@@ -231,11 +231,25 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetAntigravityModels()
|
||||
case "codebuddy":
|
||||
return GetCodeBuddyModels()
|
||||
case "cursor":
|
||||
return GetCursorModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCursorModels returns the fallback Cursor model definitions.
|
||||
func GetCursorModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
|
||||
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
|
||||
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
|
||||
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
@@ -260,6 +274,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
GetCodeBuddyModels(),
|
||||
GetCursorModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
|
||||
125
internal/runtime/executor/codex_continuity.go
Normal file
125
internal/runtime/executor/codex_continuity.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type codexContinuity struct {
|
||||
Key string
|
||||
Source string
|
||||
}
|
||||
|
||||
func metadataString(meta map[string]any, key string) string {
|
||||
if len(meta) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := meta[key]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func principalString(raw any) string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(v.String())
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", raw))
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCodexContinuity(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) codexContinuity {
|
||||
if promptCacheKey := strings.TrimSpace(gjson.GetBytes(req.Payload, "prompt_cache_key").String()); promptCacheKey != "" {
|
||||
return codexContinuity{Key: promptCacheKey, Source: "prompt_cache_key"}
|
||||
}
|
||||
if executionSession := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); executionSession != "" {
|
||||
return codexContinuity{Key: executionSession, Source: "execution_session"}
|
||||
}
|
||||
if ginCtx := ginContextFrom(ctx); ginCtx != nil {
|
||||
if ginCtx.Request != nil {
|
||||
if v := strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")); v != "" {
|
||||
return codexContinuity{Key: v, Source: "idempotency_key"}
|
||||
}
|
||||
}
|
||||
if v, exists := ginCtx.Get("apiKey"); exists && v != nil {
|
||||
if trimmed := principalString(v); trimmed != "" {
|
||||
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+trimmed)).String(), Source: "client_principal"}
|
||||
}
|
||||
}
|
||||
}
|
||||
if auth != nil {
|
||||
if authID := strings.TrimSpace(auth.ID); authID != "" {
|
||||
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:"+authID)).String(), Source: "auth_id"}
|
||||
}
|
||||
}
|
||||
return codexContinuity{}
|
||||
}
|
||||
|
||||
func applyCodexContinuityBody(rawJSON []byte, continuity codexContinuity) []byte {
|
||||
if continuity.Key == "" {
|
||||
return rawJSON
|
||||
}
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", continuity.Key)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
func applyCodexContinuityHeaders(headers http.Header, continuity codexContinuity) {
|
||||
if headers == nil || continuity.Key == "" {
|
||||
return
|
||||
}
|
||||
headers.Set("session_id", continuity.Key)
|
||||
}
|
||||
|
||||
func logCodexRequestDiagnostics(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, headers http.Header, body []byte, continuity codexContinuity) {
|
||||
if !log.IsLevelEnabled(log.DebugLevel) {
|
||||
return
|
||||
}
|
||||
entry := logWithRequestID(ctx)
|
||||
authID := ""
|
||||
authFile := ""
|
||||
if auth != nil {
|
||||
authID = strings.TrimSpace(auth.ID)
|
||||
authFile = strings.TrimSpace(auth.FileName)
|
||||
}
|
||||
selectedAuthID := metadataString(opts.Metadata, cliproxyexecutor.SelectedAuthMetadataKey)
|
||||
executionSessionID := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey)
|
||||
entry.Debugf(
|
||||
"codex request diagnostics auth_id=%s selected_auth_id=%s auth_file=%s exec_session=%s continuity_source=%s session_id=%s prompt_cache_key=%s prompt_cache_retention=%s store=%t has_instructions=%t reasoning_effort=%s reasoning_summary=%s chatgpt_account_id=%t originator=%s model=%s source_format=%s",
|
||||
authID,
|
||||
selectedAuthID,
|
||||
authFile,
|
||||
executionSessionID,
|
||||
continuity.Source,
|
||||
strings.TrimSpace(headers.Get("session_id")),
|
||||
gjson.GetBytes(body, "prompt_cache_key").String(),
|
||||
gjson.GetBytes(body, "prompt_cache_retention").String(),
|
||||
gjson.GetBytes(body, "store").Bool(),
|
||||
gjson.GetBytes(body, "instructions").Exists(),
|
||||
gjson.GetBytes(body, "reasoning.effort").String(),
|
||||
gjson.GetBytes(body, "reasoning.summary").String(),
|
||||
strings.TrimSpace(headers.Get("Chatgpt-Account-Id")) != "",
|
||||
strings.TrimSpace(headers.Get("Originator")),
|
||||
req.Model,
|
||||
opts.SourceFormat.String(),
|
||||
)
|
||||
}
|
||||
@@ -111,18 +111,19 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -222,11 +223,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
body, _ = sjson.DeleteBytes(body, "stream")
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -309,19 +311,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -415,6 +418,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
@@ -596,8 +600,9 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
|
||||
func (e *CodexExecutor) cacheHelper(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, url string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) (*http.Request, codexContinuity, error) {
|
||||
var cache codexCache
|
||||
continuity := codexContinuity{}
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
@@ -610,30 +615,26 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
}
|
||||
setCodexCache(key, cache)
|
||||
}
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key")
|
||||
if promptCacheKey.Exists() {
|
||||
cache.ID = promptCacheKey.String()
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
|
||||
}
|
||||
} else if from == "openai" {
|
||||
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
||||
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||
}
|
||||
continuity = resolveCodexContinuity(ctx, auth, req, opts)
|
||||
cache.ID = continuity.Key
|
||||
}
|
||||
|
||||
if cache.ID != "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||
}
|
||||
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, continuity, err
|
||||
}
|
||||
if cache.ID != "" {
|
||||
httpReq.Header.Set("Conversation_id", cache.ID)
|
||||
httpReq.Header.Set("Session_id", cache.ID)
|
||||
}
|
||||
return httpReq, nil
|
||||
applyCodexContinuityHeaders(httpReq.Header, continuity)
|
||||
return httpReq, continuity, nil
|
||||
}
|
||||
|
||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
|
||||
@@ -646,7 +647,7 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||
@@ -685,13 +686,39 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
errCode := statusCode
|
||||
if isCodexModelCapacityError(body) {
|
||||
errCode = http.StatusTooManyRequests
|
||||
}
|
||||
err := statusErr{code: errCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func isCodexModelCapacityError(errorBody []byte) bool {
|
||||
if len(errorBody) == 0 {
|
||||
return false
|
||||
}
|
||||
candidates := []string{
|
||||
gjson.GetBytes(errorBody, "error.message").String(),
|
||||
gjson.GetBytes(errorBody, "message").String(),
|
||||
string(errorBody),
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
lower := strings.ToLower(strings.TrimSpace(candidate))
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(lower, "selected model is at capacity") ||
|
||||
strings.Contains(lower, "model is at capacity. please try a different model") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -27,7 +28,7 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
}
|
||||
url := "https://example.com/responses"
|
||||
|
||||
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
httpReq, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
@@ -42,14 +43,14 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
if gotKey != expectedKey {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||
}
|
||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
||||
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
||||
if gotSession := httpReq.Header.Get("session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("session_id = %q, want %q", gotSession, expectedKey)
|
||||
}
|
||||
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||
if got := httpReq.Header.Get("Conversation_id"); got != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", got)
|
||||
}
|
||||
|
||||
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
httpReq2, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error (second call): %v", err)
|
||||
}
|
||||
@@ -62,3 +63,118 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIResponses_PreservesPromptCacheRetention(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
url := "https://example.com/responses"
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.3-codex",
|
||||
Payload: []byte(`{"model":"gpt-5.3-codex","prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
|
||||
}
|
||||
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true,"prompt_cache_retention":"persistent"}`)
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai-response"), url, req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "cache-key-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := gjson.GetBytes(body, "prompt_cache_retention").String(); got != "persistent" {
|
||||
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != "cache-key-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := httpReq.Header.Get("Conversation_id"); got != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_UsesExecutionSessionForContinuity(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4"}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{Metadata: map[string]any{cliproxyexecutor.ExecutionSessionMetadataKey: "exec-session-1"}}
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai"), "https://example.com/responses", req, opts, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "exec-session-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "exec-session-1")
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != "exec-session-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "exec-session-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_FallsBackToStableAuthID(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4"}`),
|
||||
}
|
||||
auth := &cliproxyauth.Auth{ID: "codex-auth-1", Provider: "codex"}
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), auth, sdktranslator.FromString("openai"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
expected := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:codex-auth-1")).String()
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != expected {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, expected)
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != expected {
|
||||
t.Fatalf("session_id = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_ClaudePreservesCacheContinuity(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "claude-3-7-sonnet",
|
||||
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
|
||||
}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
|
||||
httpReq, continuity, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("claude"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
if continuity.Key == "" {
|
||||
t.Fatal("continuity.Key = empty, want non-empty")
|
||||
}
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != continuity.Key {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != continuity.Key {
|
||||
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
|
||||
|
||||
err := newCodexStatusErr(http.StatusBadRequest, body)
|
||||
|
||||
if got := err.StatusCode(); got != http.StatusTooManyRequests {
|
||||
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||
}
|
||||
if err.RetryAfter() != nil {
|
||||
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
@@ -178,7 +178,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
@@ -190,7 +189,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
@@ -209,6 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -385,7 +385,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
@@ -403,6 +403,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -761,13 +762,14 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) {
|
||||
func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) ([]byte, http.Header, codexContinuity) {
|
||||
headers := http.Header{}
|
||||
if len(rawJSON) == 0 {
|
||||
return rawJSON, headers
|
||||
return rawJSON, headers, codexContinuity{}
|
||||
}
|
||||
|
||||
var cache codexCache
|
||||
continuity := codexContinuity{}
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
@@ -781,20 +783,22 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
}
|
||||
setCodexCache(key, cache)
|
||||
}
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
|
||||
cache.ID = promptCacheKey.String()
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
|
||||
}
|
||||
} else if from == "openai" {
|
||||
continuity = resolveCodexContinuity(ctx, auth, req, opts)
|
||||
cache.ID = continuity.Key
|
||||
}
|
||||
|
||||
if cache.ID != "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||
headers.Set("Conversation_id", cache.ID)
|
||||
headers.Set("Session_id", cache.ID)
|
||||
}
|
||||
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
|
||||
applyCodexContinuityHeaders(headers, continuity)
|
||||
|
||||
return rawJSON, headers
|
||||
return rawJSON, headers, continuity
|
||||
}
|
||||
|
||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
|
||||
@@ -826,7 +830,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(headers, ginHeaders, "session_id", uuid.NewString())
|
||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
isAPIKey := false
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -32,6 +34,49 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexPromptCacheHeaders_PreservesPromptCacheRetention(t *testing.T) {
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5-codex",
|
||||
Payload: []byte(`{"prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
|
||||
}
|
||||
body := []byte(`{"model":"gpt-5-codex","stream":true,"prompt_cache_retention":"persistent"}`)
|
||||
|
||||
updatedBody, headers, _ := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("openai-response"), req, cliproxyexecutor.Options{}, body)
|
||||
|
||||
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != "cache-key-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := gjson.GetBytes(updatedBody, "prompt_cache_retention").String(); got != "persistent" {
|
||||
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
|
||||
}
|
||||
if got := headers.Get("session_id"); got != "cache-key-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := headers.Get("Conversation_id"); got != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexPromptCacheHeaders_ClaudePreservesContinuity(t *testing.T) {
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "claude-3-7-sonnet",
|
||||
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
|
||||
}
|
||||
body := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
|
||||
updatedBody, headers, continuity := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("claude"), req, cliproxyexecutor.Options{}, body)
|
||||
|
||||
if continuity.Key == "" {
|
||||
t.Fatal("continuity.Key = empty, want non-empty")
|
||||
}
|
||||
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != continuity.Key {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
if got := headers.Get("session_id"); got != continuity.Key {
|
||||
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||
|
||||
|
||||
1719
internal/runtime/executor/cursor_executor.go
Normal file
1719
internal/runtime/executor/cursor_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -30,12 +30,20 @@ const (
|
||||
gitLabChatEndpoint = "/api/v4/chat/completions"
|
||||
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
|
||||
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
|
||||
gitLabContext1MBeta = "context-1m-2025-08-07"
|
||||
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
|
||||
)
|
||||
|
||||
type GitLabExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type gitLabCatalogModel struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Provider string
|
||||
}
|
||||
|
||||
type gitLabPrompt struct {
|
||||
Instruction string
|
||||
FileName string
|
||||
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
|
||||
Finished bool
|
||||
}
|
||||
|
||||
var gitLabAgenticCatalog = []gitLabCatalogModel{
|
||||
{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
|
||||
{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
|
||||
{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
|
||||
}
|
||||
|
||||
var gitLabModelAliases = map[string]string{
|
||||
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
|
||||
}
|
||||
|
||||
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
||||
return &GitLabExecutor{cfg: cfg}
|
||||
}
|
||||
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
|
||||
nativeReq := req
|
||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||
}
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
|
||||
nativeReq := req
|
||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
|
||||
return NewClaudeExecutor(e.cfg), nativeAuth
|
||||
}
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
|
||||
return NewCodexExecutor(e.cfg), nativeAuth
|
||||
}
|
||||
return nil, nil
|
||||
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
if auth != nil {
|
||||
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
|
||||
}
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
for key, value := range gitLabGatewayHeaders(auth, "") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
}
|
||||
}
|
||||
|
||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
|
||||
out := make(map[string]string)
|
||||
switch typed := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key != "" && value != "" {
|
||||
out[key] = value
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||
if ok {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key != "" && value != "" {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||
if strValue != "" {
|
||||
out[key] = strValue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||
if strValue != "" {
|
||||
out[key] = strValue
|
||||
}
|
||||
}
|
||||
if _, ok := out["User-Agent"]; !ok {
|
||||
out["User-Agent"] = gitLabNativeUserAgent
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
|
||||
if _, ok := out["anthropic-beta"]; !ok {
|
||||
out["anthropic-beta"] = gitLabContext1MBeta
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
|
||||
return promptTokens, int64(completionCount)
|
||||
}
|
||||
|
||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesAnthropicGateway(auth) {
|
||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesAnthropicGateway(auth, requestedModel) {
|
||||
return nil, false
|
||||
}
|
||||
baseURL := gitLabAnthropicGatewayBaseURL(auth)
|
||||
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
||||
}
|
||||
nativeAuth.Attributes["api_key"] = token
|
||||
nativeAuth.Attributes["base_url"] = baseURL
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
|
||||
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
||||
return nativeAuth, true
|
||||
}
|
||||
|
||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesOpenAIGateway(auth) {
|
||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesOpenAIGateway(auth, requestedModel) {
|
||||
return nil, false
|
||||
}
|
||||
baseURL := gitLabOpenAIGatewayBaseURL(auth)
|
||||
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
||||
}
|
||||
nativeAuth.Attributes["api_key"] = token
|
||||
nativeAuth.Attributes["base_url"] = baseURL
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
for key, value := range gitLabGatewayHeaders(auth, "openai") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
||||
return nativeAuth, true
|
||||
}
|
||||
|
||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool {
|
||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
provider = inferGitLabProviderFromModel(modelName)
|
||||
}
|
||||
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||
return provider == "anthropic" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||
}
|
||||
|
||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool {
|
||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
provider = inferGitLabProviderFromModel(modelName)
|
||||
}
|
||||
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||
return provider == "openai" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||
}
|
||||
|
||||
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
|
||||
modelName := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
|
||||
if provider := inferGitLabProviderFromModel(modelName); provider != "" {
|
||||
return provider
|
||||
}
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
provider = inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func inferGitLabProviderFromModel(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
|
||||
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
|
||||
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
|
||||
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
|
||||
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
|
||||
return mapped
|
||||
}
|
||||
return requested
|
||||
}
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
|
||||
}
|
||||
|
||||
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
models := make([]*registry.ModelInfo, 0, 4)
|
||||
seen := make(map[string]struct{}, 4)
|
||||
models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
|
||||
seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
|
||||
addModel := func(id, displayName, provider string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
}
|
||||
|
||||
addModel("gitlab-duo", "GitLab Duo", "gitlab")
|
||||
for _, model := range gitLabAgenticCatalog {
|
||||
addModel(model.ID, model.DisplayName, model.Provider)
|
||||
}
|
||||
for alias, upstream := range gitLabModelAliases {
|
||||
target := strings.TrimSpace(upstream)
|
||||
displayName := "GitLab Duo Alias"
|
||||
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
|
||||
if provider != "" {
|
||||
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
|
||||
}
|
||||
addModel(alias, displayName, provider)
|
||||
}
|
||||
if auth == nil {
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotBetaHeader = r.Header.Get("anthropic-beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "duo-chat-gpt-5-codex",
|
||||
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotBetaHeader != gitLabContext1MBeta {
|
||||
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if gotModel != "duo-chat-gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
|
||||
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"oauth_client_secret": "client-secret",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotPath, gotBetaHeader, gotUserAgent string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotBetaHeader = r.Header.Get("Anthropic-Beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
|
||||
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
type oaiToResponsesStateReasoning struct {
|
||||
ReasoningID string
|
||||
ReasoningData string
|
||||
OutputIndex int
|
||||
}
|
||||
type oaiToResponsesState struct {
|
||||
Seq int
|
||||
@@ -29,16 +31,19 @@ type oaiToResponsesState struct {
|
||||
MsgTextBuf map[int]*strings.Builder
|
||||
ReasoningBuf strings.Builder
|
||||
Reasonings []oaiToResponsesStateReasoning
|
||||
FuncArgsBuf map[int]*strings.Builder // index -> args
|
||||
FuncNames map[int]string // index -> name
|
||||
FuncCallIDs map[int]string // index -> call_id
|
||||
FuncArgsBuf map[string]*strings.Builder
|
||||
FuncNames map[string]string
|
||||
FuncCallIDs map[string]string
|
||||
FuncOutputIx map[string]int
|
||||
MsgOutputIx map[int]int
|
||||
NextOutputIx int
|
||||
// message item state per output index
|
||||
MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
|
||||
MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
|
||||
MsgItemDone map[int]bool // whether message done events were emitted
|
||||
// function item done state
|
||||
FuncArgsDone map[int]bool
|
||||
FuncItemDone map[int]bool
|
||||
FuncArgsDone map[string]bool
|
||||
FuncItemDone map[string]bool
|
||||
// usage aggregation
|
||||
PromptTokens int64
|
||||
CachedTokens int64
|
||||
@@ -60,15 +65,17 @@ func emitRespEvent(event string, payload []byte) []byte {
|
||||
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &oaiToResponsesState{
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncArgsBuf: make(map[string]*strings.Builder),
|
||||
FuncNames: make(map[string]string),
|
||||
FuncCallIDs: make(map[string]string),
|
||||
FuncOutputIx: make(map[string]int),
|
||||
MsgOutputIx: make(map[int]int),
|
||||
MsgTextBuf: make(map[int]*strings.Builder),
|
||||
MsgItemAdded: make(map[int]bool),
|
||||
MsgContentAdded: make(map[int]bool),
|
||||
MsgItemDone: make(map[int]bool),
|
||||
FuncArgsDone: make(map[int]bool),
|
||||
FuncItemDone: make(map[int]bool),
|
||||
FuncArgsDone: make(map[string]bool),
|
||||
FuncItemDone: make(map[string]bool),
|
||||
Reasonings: make([]oaiToResponsesStateReasoning, 0),
|
||||
}
|
||||
}
|
||||
@@ -125,6 +132,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
|
||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||
allocOutputIndex := func() int {
|
||||
ix := st.NextOutputIx
|
||||
st.NextOutputIx++
|
||||
return ix
|
||||
}
|
||||
toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
|
||||
var out [][]byte
|
||||
|
||||
if !st.Started {
|
||||
@@ -135,14 +148,17 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
st.ReasoningBuf.Reset()
|
||||
st.ReasoningID = ""
|
||||
st.ReasoningIndex = 0
|
||||
st.FuncArgsBuf = make(map[int]*strings.Builder)
|
||||
st.FuncNames = make(map[int]string)
|
||||
st.FuncCallIDs = make(map[int]string)
|
||||
st.FuncArgsBuf = make(map[string]*strings.Builder)
|
||||
st.FuncNames = make(map[string]string)
|
||||
st.FuncCallIDs = make(map[string]string)
|
||||
st.FuncOutputIx = make(map[string]int)
|
||||
st.MsgOutputIx = make(map[int]int)
|
||||
st.NextOutputIx = 0
|
||||
st.MsgItemAdded = make(map[int]bool)
|
||||
st.MsgContentAdded = make(map[int]bool)
|
||||
st.MsgItemDone = make(map[int]bool)
|
||||
st.FuncArgsDone = make(map[int]bool)
|
||||
st.FuncItemDone = make(map[int]bool)
|
||||
st.FuncArgsDone = make(map[string]bool)
|
||||
st.FuncItemDone = make(map[string]bool)
|
||||
st.PromptTokens = 0
|
||||
st.CachedTokens = 0
|
||||
st.CompletionTokens = 0
|
||||
@@ -185,7 +201,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
|
||||
out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
|
||||
|
||||
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
|
||||
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
|
||||
st.ReasoningID = ""
|
||||
}
|
||||
|
||||
@@ -201,10 +217,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
stopReasoning(st.ReasoningBuf.String())
|
||||
st.ReasoningBuf.Reset()
|
||||
}
|
||||
if _, exists := st.MsgOutputIx[idx]; !exists {
|
||||
st.MsgOutputIx[idx] = allocOutputIndex()
|
||||
}
|
||||
msgOutputIndex := st.MsgOutputIx[idx]
|
||||
if !st.MsgItemAdded[idx] {
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
|
||||
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
out = append(out, emitRespEvent("response.output_item.added", item))
|
||||
st.MsgItemAdded[idx] = true
|
||||
@@ -213,7 +233,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
part, _ = sjson.SetBytes(part, "output_index", idx)
|
||||
part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
|
||||
part, _ = sjson.SetBytes(part, "content_index", 0)
|
||||
out = append(out, emitRespEvent("response.content_part.added", part))
|
||||
st.MsgContentAdded[idx] = true
|
||||
@@ -222,7 +242,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", idx)
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
|
||||
msg, _ = sjson.SetBytes(msg, "content_index", 0)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", c.String())
|
||||
out = append(out, emitRespEvent("response.output_text.delta", msg))
|
||||
@@ -238,10 +258,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
// On first appearance, add reasoning item and part
|
||||
if st.ReasoningID == "" {
|
||||
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
|
||||
st.ReasoningIndex = idx
|
||||
st.ReasoningIndex = allocOutputIndex()
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
|
||||
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
|
||||
out = append(out, emitRespEvent("response.output_item.added", item))
|
||||
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||
@@ -269,6 +289,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
// Before emitting any function events, if a message is open for this index,
|
||||
// close its text/content to match Codex expected ordering.
|
||||
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
|
||||
msgOutputIndex := st.MsgOutputIx[idx]
|
||||
fullText := ""
|
||||
if b := st.MsgTextBuf[idx]; b != nil {
|
||||
fullText = b.String()
|
||||
@@ -276,7 +297,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
done, _ = sjson.SetBytes(done, "output_index", idx)
|
||||
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||
@@ -284,69 +305,72 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", idx)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
st.MsgItemDone[idx] = true
|
||||
}
|
||||
|
||||
// Only emit item.added once per tool call and preserve call_id across chunks.
|
||||
newCallID := tcs.Get("0.id").String()
|
||||
nameChunk := tcs.Get("0.function.name").String()
|
||||
if nameChunk != "" {
|
||||
st.FuncNames[idx] = nameChunk
|
||||
}
|
||||
existingCallID := st.FuncCallIDs[idx]
|
||||
effectiveCallID := existingCallID
|
||||
shouldEmitItem := false
|
||||
if existingCallID == "" && newCallID != "" {
|
||||
// First time seeing a valid call_id for this index
|
||||
effectiveCallID = newCallID
|
||||
st.FuncCallIDs[idx] = newCallID
|
||||
shouldEmitItem = true
|
||||
}
|
||||
|
||||
if shouldEmitItem && effectiveCallID != "" {
|
||||
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
|
||||
o, _ = sjson.SetBytes(o, "output_index", idx)
|
||||
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
|
||||
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
|
||||
name := st.FuncNames[idx]
|
||||
o, _ = sjson.SetBytes(o, "item.name", name)
|
||||
out = append(out, emitRespEvent("response.output_item.added", o))
|
||||
}
|
||||
|
||||
// Ensure args buffer exists for this index
|
||||
if st.FuncArgsBuf[idx] == nil {
|
||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||
}
|
||||
|
||||
// Append arguments delta if available and we have a valid call_id to reference
|
||||
if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
|
||||
// Prefer an already known call_id; fall back to newCallID if first time
|
||||
refCallID := st.FuncCallIDs[idx]
|
||||
if refCallID == "" {
|
||||
refCallID = newCallID
|
||||
tcs.ForEach(func(_, tc gjson.Result) bool {
|
||||
toolIndex := int(tc.Get("index").Int())
|
||||
key := toolStateKey(idx, toolIndex)
|
||||
newCallID := tc.Get("id").String()
|
||||
nameChunk := tc.Get("function.name").String()
|
||||
if nameChunk != "" {
|
||||
st.FuncNames[key] = nameChunk
|
||||
}
|
||||
if refCallID != "" {
|
||||
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
|
||||
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
|
||||
ad, _ = sjson.SetBytes(ad, "output_index", idx)
|
||||
ad, _ = sjson.SetBytes(ad, "delta", args.String())
|
||||
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
|
||||
|
||||
existingCallID := st.FuncCallIDs[key]
|
||||
effectiveCallID := existingCallID
|
||||
shouldEmitItem := false
|
||||
if existingCallID == "" && newCallID != "" {
|
||||
effectiveCallID = newCallID
|
||||
st.FuncCallIDs[key] = newCallID
|
||||
st.FuncOutputIx[key] = allocOutputIndex()
|
||||
shouldEmitItem = true
|
||||
}
|
||||
st.FuncArgsBuf[idx].WriteString(args.String())
|
||||
}
|
||||
|
||||
if shouldEmitItem && effectiveCallID != "" {
|
||||
outputIndex := st.FuncOutputIx[key]
|
||||
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
|
||||
o, _ = sjson.SetBytes(o, "output_index", outputIndex)
|
||||
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
|
||||
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
|
||||
o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
|
||||
out = append(out, emitRespEvent("response.output_item.added", o))
|
||||
}
|
||||
|
||||
if st.FuncArgsBuf[key] == nil {
|
||||
st.FuncArgsBuf[key] = &strings.Builder{}
|
||||
}
|
||||
|
||||
if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
|
||||
refCallID := st.FuncCallIDs[key]
|
||||
if refCallID == "" {
|
||||
refCallID = newCallID
|
||||
}
|
||||
if refCallID != "" {
|
||||
outputIndex := st.FuncOutputIx[key]
|
||||
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
|
||||
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
|
||||
ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
|
||||
ad, _ = sjson.SetBytes(ad, "delta", args.String())
|
||||
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
|
||||
}
|
||||
st.FuncArgsBuf[key].WriteString(args.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,15 +384,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
for i := range st.MsgItemAdded {
|
||||
idxs = append(idxs, i)
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
|
||||
for _, i := range idxs {
|
||||
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
|
||||
msgOutputIndex := st.MsgOutputIx[i]
|
||||
fullText := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
fullText = b.String()
|
||||
@@ -376,7 +395,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
done, _ = sjson.SetBytes(done, "output_index", i)
|
||||
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||
@@ -384,14 +403,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", i)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
@@ -407,43 +426,42 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
|
||||
// Emit function call done events for any active function calls
|
||||
if len(st.FuncCallIDs) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncCallIDs))
|
||||
for i := range st.FuncCallIDs {
|
||||
idxs = append(idxs, i)
|
||||
keys := make([]string, 0, len(st.FuncCallIDs))
|
||||
for key := range st.FuncCallIDs {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range idxs {
|
||||
callID := st.FuncCallIDs[i]
|
||||
if callID == "" || st.FuncItemDone[i] {
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
left := st.FuncOutputIx[keys[i]]
|
||||
right := st.FuncOutputIx[keys[j]]
|
||||
return left < right || (left == right && keys[i] < keys[j])
|
||||
})
|
||||
for _, key := range keys {
|
||||
callID := st.FuncCallIDs[key]
|
||||
if callID == "" || st.FuncItemDone[key] {
|
||||
continue
|
||||
}
|
||||
outputIndex := st.FuncOutputIx[key]
|
||||
args := "{}"
|
||||
if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
|
||||
if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
|
||||
args = b.String()
|
||||
}
|
||||
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", i)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
|
||||
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i])
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
st.FuncItemDone[i] = true
|
||||
st.FuncArgsDone[i] = true
|
||||
st.FuncItemDone[key] = true
|
||||
st.FuncArgsDone[key] = true
|
||||
}
|
||||
}
|
||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||
@@ -516,28 +534,21 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
// Build response.output using aggregated buffers
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
type completedOutputItem struct {
|
||||
index int
|
||||
raw []byte
|
||||
}
|
||||
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
|
||||
if len(st.Reasonings) > 0 {
|
||||
for _, r := range st.Reasonings {
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
|
||||
}
|
||||
}
|
||||
// Append message items in ascending index order
|
||||
if len(st.MsgItemAdded) > 0 {
|
||||
midxs := make([]int, 0, len(st.MsgItemAdded))
|
||||
for i := range st.MsgItemAdded {
|
||||
midxs = append(midxs, i)
|
||||
}
|
||||
for i := 0; i < len(midxs); i++ {
|
||||
for j := i + 1; j < len(midxs); j++ {
|
||||
if midxs[j] < midxs[i] {
|
||||
midxs[i], midxs[j] = midxs[j], midxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range midxs {
|
||||
txt := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
txt = b.String()
|
||||
@@ -545,37 +556,29 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", txt)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
|
||||
}
|
||||
}
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncArgsBuf))
|
||||
for i := range st.FuncArgsBuf {
|
||||
idxs = append(idxs, i)
|
||||
}
|
||||
// small-N sort without extra imports
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range idxs {
|
||||
for key := range st.FuncArgsBuf {
|
||||
args := ""
|
||||
if b := st.FuncArgsBuf[i]; b != nil {
|
||||
if b := st.FuncArgsBuf[key]; b != nil {
|
||||
args = b.String()
|
||||
}
|
||||
callID := st.FuncCallIDs[i]
|
||||
name := st.FuncNames[i]
|
||||
callID := st.FuncCallIDs[key]
|
||||
name := st.FuncNames[key]
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
|
||||
}
|
||||
}
|
||||
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
|
||||
for _, item := range outputItems {
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
|
||||
}
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
|
||||
t.Helper()
|
||||
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
if len(lines) < 2 {
|
||||
t.Fatalf("unexpected SSE chunk: %q", chunk)
|
||||
}
|
||||
|
||||
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||
if !gjson.Valid(dataLine) {
|
||||
t.Fatalf("invalid SSE data JSON: %q", dataLine)
|
||||
}
|
||||
return event, gjson.Parse(dataLine)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
addedNames := map[string]string{}
|
||||
doneArgs := map[string]string{}
|
||||
doneNames := map[string]string{}
|
||||
outputItems := map[string]gjson.Result{}
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
addedNames[data.Get("item.call_id").String()] = data.Get("item.name").String()
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := data.Get("item.call_id").String()
|
||||
doneArgs[callID] = data.Get("item.arguments").String()
|
||||
doneNames[callID] = data.Get("item.name").String()
|
||||
case "response.completed":
|
||||
output := data.Get("response.output")
|
||||
for _, item := range output.Array() {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
outputItems[item.Get("call_id").String()] = item
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(addedNames) != 2 {
|
||||
t.Fatalf("expected 2 function_call added events, got %d", len(addedNames))
|
||||
}
|
||||
if len(doneArgs) != 2 {
|
||||
t.Fatalf("expected 2 function_call done events, got %d", len(doneArgs))
|
||||
}
|
||||
|
||||
if addedNames["call_read"] != "read" {
|
||||
t.Fatalf("unexpected added name for call_read: %q", addedNames["call_read"])
|
||||
}
|
||||
if addedNames["call_glob"] != "glob" {
|
||||
t.Fatalf("unexpected added name for call_glob: %q", addedNames["call_glob"])
|
||||
}
|
||||
|
||||
if !gjson.Valid(doneArgs["call_read"]) {
|
||||
t.Fatalf("invalid JSON args for call_read: %q", doneArgs["call_read"])
|
||||
}
|
||||
if !gjson.Valid(doneArgs["call_glob"]) {
|
||||
t.Fatalf("invalid JSON args for call_glob: %q", doneArgs["call_glob"])
|
||||
}
|
||||
if strings.Contains(doneArgs["call_read"], "}{") {
|
||||
t.Fatalf("call_read args were concatenated: %q", doneArgs["call_read"])
|
||||
}
|
||||
if strings.Contains(doneArgs["call_glob"], "}{") {
|
||||
t.Fatalf("call_glob args were concatenated: %q", doneArgs["call_glob"])
|
||||
}
|
||||
|
||||
if doneNames["call_read"] != "read" {
|
||||
t.Fatalf("unexpected done name for call_read: %q", doneNames["call_read"])
|
||||
}
|
||||
if doneNames["call_glob"] != "glob" {
|
||||
t.Fatalf("unexpected done name for call_glob: %q", doneNames["call_glob"])
|
||||
}
|
||||
|
||||
if got := gjson.Get(doneArgs["call_read"], "filePath").String(); got != `C:\repo` {
|
||||
t.Fatalf("unexpected filePath for call_read: %q", got)
|
||||
}
|
||||
if got := gjson.Get(doneArgs["call_glob"], "path").String(); got != `C:\repo` {
|
||||
t.Fatalf("unexpected path for call_glob: %q", got)
|
||||
}
|
||||
if got := gjson.Get(doneArgs["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
|
||||
t.Fatalf("unexpected pattern for call_glob: %q", got)
|
||||
}
|
||||
|
||||
if len(outputItems) != 2 {
|
||||
t.Fatalf("expected 2 function_call items in response.output, got %d", len(outputItems))
|
||||
}
|
||||
if outputItems["call_read"].Get("name").String() != "read" {
|
||||
t.Fatalf("unexpected response.output name for call_read: %q", outputItems["call_read"].Get("name").String())
|
||||
}
|
||||
if outputItems["call_glob"].Get("name").String() != "glob" {
|
||||
t.Fatalf("unexpected response.output name for call_glob: %q", outputItems["call_glob"].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
type fcEvent struct {
|
||||
outputIndex int64
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
added := map[string]fcEvent{}
|
||||
done := map[string]fcEvent{}
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := data.Get("item.call_id").String()
|
||||
added[callID] = fcEvent{
|
||||
outputIndex: data.Get("output_index").Int(),
|
||||
name: data.Get("item.name").String(),
|
||||
}
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := data.Get("item.call_id").String()
|
||||
done[callID] = fcEvent{
|
||||
outputIndex: data.Get("output_index").Int(),
|
||||
name: data.Get("item.name").String(),
|
||||
arguments: data.Get("item.arguments").String(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(added) != 2 {
|
||||
t.Fatalf("expected 2 function_call added events, got %d", len(added))
|
||||
}
|
||||
if len(done) != 2 {
|
||||
t.Fatalf("expected 2 function_call done events, got %d", len(done))
|
||||
}
|
||||
|
||||
if added["call_choice0"].name != "glob" {
|
||||
t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
|
||||
}
|
||||
if added["call_choice1"].name != "read" {
|
||||
t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
|
||||
}
|
||||
if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
|
||||
t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
|
||||
}
|
||||
|
||||
if !gjson.Valid(done["call_choice0"].arguments) {
|
||||
t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
|
||||
}
|
||||
if !gjson.Valid(done["call_choice1"].arguments) {
|
||||
t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
|
||||
}
|
||||
if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
|
||||
t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
|
||||
}
|
||||
if done["call_choice0"].name != "glob" {
|
||||
t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
|
||||
}
|
||||
if done["call_choice1"].name != "read" {
|
||||
t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var messageOutputIndex int64 = -1
|
||||
var toolOutputIndex int64 = -1
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
if ev != "response.output_item.added" {
|
||||
continue
|
||||
}
|
||||
switch data.Get("item.type").String() {
|
||||
case "message":
|
||||
if data.Get("item.id").String() == "msg_resp_mixed_0" {
|
||||
messageOutputIndex = data.Get("output_index").Int()
|
||||
}
|
||||
case "function_call":
|
||||
if data.Get("item.call_id").String() == "call_choice1" {
|
||||
toolOutputIndex = data.Get("output_index").Int()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messageOutputIndex < 0 {
|
||||
t.Fatal("did not find message output index")
|
||||
}
|
||||
if toolOutputIndex < 0 {
|
||||
t.Fatal("did not find tool output index")
|
||||
}
|
||||
if messageOutputIndex == toolOutputIndex {
|
||||
t.Fatalf("expected distinct output indexes for message and tool call, both got %d", messageOutputIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var doneIndexes []int64
|
||||
var completedOrder []string
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() == "function_call" {
|
||||
doneIndexes = append(doneIndexes, data.Get("output_index").Int())
|
||||
}
|
||||
case "response.completed":
|
||||
for _, item := range data.Get("response.output").Array() {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
completedOrder = append(completedOrder, item.Get("call_id").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(doneIndexes) != 2 {
|
||||
t.Fatalf("expected 2 function_call done indexes, got %d", len(doneIndexes))
|
||||
}
|
||||
if doneIndexes[0] >= doneIndexes[1] {
|
||||
t.Fatalf("expected ascending done output indexes, got %v", doneIndexes)
|
||||
}
|
||||
if len(completedOrder) != 2 {
|
||||
t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedOrder))
|
||||
}
|
||||
if completedOrder[0] != "call_glob" || completedOrder[1] != "call_read" {
|
||||
t.Fatalf("unexpected completed function_call order: %v", completedOrder)
|
||||
}
|
||||
}
|
||||
@@ -57,6 +57,12 @@ func GetProviderName(modelName string) []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// Fallback: if cursor provider has registered models, route unknown models to it.
|
||||
// Cursor acts as a universal proxy supporting multiple model families (Claude, GPT, Gemini, etc.).
|
||||
if models := registry.GetGlobalRegistry().GetAvailableModelsByProvider("cursor"); len(models) > 0 {
|
||||
return []string{"cursor"}
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
|
||||
98
sdk/auth/cursor.go
Normal file
98
sdk/auth/cursor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CursorAuthenticator implements OAuth PKCE login for Cursor.
|
||||
type CursorAuthenticator struct{}
|
||||
|
||||
// NewCursorAuthenticator constructs a new Cursor authenticator.
|
||||
func NewCursorAuthenticator() Authenticator {
|
||||
return &CursorAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for cursor.
|
||||
func (CursorAuthenticator) Provider() string {
|
||||
return "cursor"
|
||||
}
|
||||
|
||||
// RefreshLead returns the time before expiry when a refresh should be attempted.
|
||||
func (CursorAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 10 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
// Login initiates the Cursor PKCE authentication flow.
|
||||
func (a CursorAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cursor auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
// Generate PKCE auth parameters
|
||||
authParams, err := cursorauth.GenerateAuthParams()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to generate auth params: %w", err)
|
||||
}
|
||||
|
||||
// Display the login URL
|
||||
log.Info("Starting Cursor authentication...")
|
||||
log.Infof("Please visit this URL to log in: %s", authParams.LoginURL)
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(authParams.LoginURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Waiting for Cursor authorization...")
|
||||
|
||||
// Poll for the auth result
|
||||
tokens, err := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: authentication failed: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := cursorauth.GetTokenExpiry(tokens.AccessToken)
|
||||
|
||||
// Auto-identify account from JWT sub claim
|
||||
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
|
||||
subHash := cursorauth.SubToShortHash(sub)
|
||||
|
||||
log.Info("Cursor authentication successful!")
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "cursor",
|
||||
"access_token": tokens.AccessToken,
|
||||
"refresh_token": tokens.RefreshToken,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
if sub != "" {
|
||||
metadata["sub"] = sub
|
||||
}
|
||||
|
||||
fileName := cursorauth.CredentialFileName("", subHash)
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: cursorauth.DisplayLabel("", subHash),
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
@@ -209,9 +209,6 @@ waitForCallback:
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata[gitLabOAuthClientIDMetadataKey] = clientID
|
||||
if strings.TrimSpace(clientSecret) != "" {
|
||||
metadata[gitLabOAuthClientSecretMetadataKey] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
|
||||
@@ -19,6 +19,7 @@ func init() {
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
|
||||
registerRefreshLead("codebuddy", func() Authenticator { return NewCodeBuddyAuthenticator() })
|
||||
registerRefreshLead("cursor", func() Authenticator { return NewCursorAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -923,8 +923,10 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
auth.Index = existing.Index
|
||||
auth.indexAssigned = existing.indexAssigned
|
||||
}
|
||||
if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
|
||||
auth.ModelStates = existing.ModelStates
|
||||
if !existing.Disabled && existing.Status != StatusDisabled && !auth.Disabled && auth.Status != StatusDisabled {
|
||||
if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
|
||||
auth.ModelStates = existing.ModelStates
|
||||
}
|
||||
}
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
|
||||
@@ -47,3 +47,158 @@ func TestManager_Update_PreservesModelStates(t *testing.T) {
|
||||
t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_DisabledExistingDoesNotInheritModelStates(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
|
||||
// Register a disabled auth with existing ModelStates.
|
||||
if _, err := m.Register(context.Background(), &Auth{
|
||||
ID: "auth-disabled",
|
||||
Provider: "claude",
|
||||
Disabled: true,
|
||||
Status: StatusDisabled,
|
||||
ModelStates: map[string]*ModelState{
|
||||
"stale-model": {
|
||||
Quota: QuotaState{BackoffLevel: 5},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
// Update with empty ModelStates — should NOT inherit stale states.
|
||||
if _, err := m.Update(context.Background(), &Auth{
|
||||
ID: "auth-disabled",
|
||||
Provider: "claude",
|
||||
Disabled: true,
|
||||
Status: StatusDisabled,
|
||||
}); err != nil {
|
||||
t.Fatalf("update auth: %v", err)
|
||||
}
|
||||
|
||||
updated, ok := m.GetByID("auth-disabled")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to be present")
|
||||
}
|
||||
if len(updated.ModelStates) != 0 {
|
||||
t.Fatalf("expected disabled auth NOT to inherit ModelStates, got %d entries", len(updated.ModelStates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_ActiveToDisabledDoesNotInheritModelStates(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
|
||||
// Register an active auth with ModelStates (simulates existing live auth).
|
||||
if _, err := m.Register(context.Background(), &Auth{
|
||||
ID: "auth-a2d",
|
||||
Provider: "claude",
|
||||
Status: StatusActive,
|
||||
ModelStates: map[string]*ModelState{
|
||||
"stale-model": {
|
||||
Quota: QuotaState{BackoffLevel: 9},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
// File watcher deletes config → synthesizes Disabled=true auth → Update.
|
||||
// Even though existing is active, incoming auth is disabled → skip inheritance.
|
||||
if _, err := m.Update(context.Background(), &Auth{
|
||||
ID: "auth-a2d",
|
||||
Provider: "claude",
|
||||
Disabled: true,
|
||||
Status: StatusDisabled,
|
||||
}); err != nil {
|
||||
t.Fatalf("update auth: %v", err)
|
||||
}
|
||||
|
||||
updated, ok := m.GetByID("auth-a2d")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to be present")
|
||||
}
|
||||
if len(updated.ModelStates) != 0 {
|
||||
t.Fatalf("expected active→disabled transition NOT to inherit ModelStates, got %d entries", len(updated.ModelStates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_DisabledToActiveDoesNotInheritStaleModelStates(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
|
||||
// Register a disabled auth with stale ModelStates.
|
||||
if _, err := m.Register(context.Background(), &Auth{
|
||||
ID: "auth-d2a",
|
||||
Provider: "claude",
|
||||
Disabled: true,
|
||||
Status: StatusDisabled,
|
||||
ModelStates: map[string]*ModelState{
|
||||
"stale-model": {
|
||||
Quota: QuotaState{BackoffLevel: 4},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
// Re-enable: incoming auth is active, existing is disabled → skip inheritance.
|
||||
if _, err := m.Update(context.Background(), &Auth{
|
||||
ID: "auth-d2a",
|
||||
Provider: "claude",
|
||||
Status: StatusActive,
|
||||
}); err != nil {
|
||||
t.Fatalf("update auth: %v", err)
|
||||
}
|
||||
|
||||
updated, ok := m.GetByID("auth-d2a")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to be present")
|
||||
}
|
||||
if len(updated.ModelStates) != 0 {
|
||||
t.Fatalf("expected disabled→active transition NOT to inherit stale ModelStates, got %d entries", len(updated.ModelStates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_ActiveInheritsModelStates(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
|
||||
model := "active-model"
|
||||
backoffLevel := 3
|
||||
|
||||
// Register an active auth with ModelStates.
|
||||
if _, err := m.Register(context.Background(), &Auth{
|
||||
ID: "auth-active",
|
||||
Provider: "claude",
|
||||
Status: StatusActive,
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Quota: QuotaState{BackoffLevel: backoffLevel},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
// Update with empty ModelStates — both sides active → SHOULD inherit.
|
||||
if _, err := m.Update(context.Background(), &Auth{
|
||||
ID: "auth-active",
|
||||
Provider: "claude",
|
||||
Status: StatusActive,
|
||||
}); err != nil {
|
||||
t.Fatalf("update auth: %v", err)
|
||||
}
|
||||
|
||||
updated, ok := m.GetByID("auth-active")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to be present")
|
||||
}
|
||||
if len(updated.ModelStates) == 0 {
|
||||
t.Fatalf("expected active auth to inherit ModelStates")
|
||||
}
|
||||
state := updated.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state to be present")
|
||||
}
|
||||
if state.Quota.BackoffLevel != backoffLevel {
|
||||
t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,6 +545,11 @@ func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
|
||||
if modelKey == "" {
|
||||
return true
|
||||
}
|
||||
// Cursor acts as a universal proxy supporting multiple model families.
|
||||
// Allow any model to be routed to cursor auth.
|
||||
if m.providerKey == "cursor" {
|
||||
return true
|
||||
}
|
||||
if len(m.supportedModelSet) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -298,10 +298,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
||||
var err error
|
||||
if existing, ok := s.coreManager.GetByID(auth.ID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
auth.NextRefreshAfter = existing.NextRefreshAfter
|
||||
if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
|
||||
auth.ModelStates = existing.ModelStates
|
||||
if !existing.Disabled && existing.Status != coreauth.StatusDisabled && !auth.Disabled && auth.Status != coreauth.StatusDisabled {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
auth.NextRefreshAfter = existing.NextRefreshAfter
|
||||
if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
|
||||
auth.ModelStates = existing.ModelStates
|
||||
}
|
||||
}
|
||||
op = "update"
|
||||
_, err = s.coreManager.Update(ctx, auth)
|
||||
@@ -441,6 +443,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||
case "kilo":
|
||||
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
|
||||
case "cursor":
|
||||
s.coreManager.RegisterExecutor(executor.NewCursorExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
case "codebuddy":
|
||||
@@ -942,6 +946,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "kimi":
|
||||
models = registry.GetKimiModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "cursor":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
models = executor.FetchCursorModels(ctx, a, s.cfg)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "github-copilot":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -46,3 +46,41 @@ func TestRegisterModelsForAuth_GitLabUsesDiscoveredModels(t *testing.T) {
|
||||
t.Fatalf("expected gitlab-duo and discovered model, got %+v", models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterModelsForAuth_GitLabIncludesAgenticCatalog(t *testing.T) {
|
||||
service := &Service{cfg: &config.Config{}}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-agentic-auth.json",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
models := reg.GetModelsForClient(auth.ID)
|
||||
if len(models) < 5 {
|
||||
t.Fatalf("expected stable alias plus built-in agentic catalog, got %d entries", len(models))
|
||||
}
|
||||
|
||||
required := map[string]bool{
|
||||
"gitlab-duo": false,
|
||||
"duo-chat-opus-4-6": false,
|
||||
"duo-chat-haiku-4-5": false,
|
||||
"duo-chat-sonnet-4-5": false,
|
||||
"duo-chat-opus-4-5": false,
|
||||
"duo-chat-gpt-5-codex": false,
|
||||
}
|
||||
for _, model := range models {
|
||||
if _, ok := required[model.ID]; ok {
|
||||
required[model.ID] = true
|
||||
}
|
||||
}
|
||||
for id, seen := range required {
|
||||
if !seen {
|
||||
t.Fatalf("expected built-in GitLab Duo model %q, got %+v", id, models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
85
sdk/cliproxy/service_stale_state_test.go
Normal file
85
sdk/cliproxy/service_stale_state_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestServiceApplyCoreAuthAddOrUpdate_DeleteReAddDoesNotInheritStaleRuntimeState(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
|
||||
authID := "service-stale-state-auth"
|
||||
modelID := "stale-model"
|
||||
lastRefreshedAt := time.Date(2026, time.March, 1, 8, 0, 0, 0, time.UTC)
|
||||
nextRefreshAfter := lastRefreshedAt.Add(30 * time.Minute)
|
||||
|
||||
t.Cleanup(func() {
|
||||
GlobalModelRegistry().UnregisterClient(authID)
|
||||
})
|
||||
|
||||
service.applyCoreAuthAddOrUpdate(context.Background(), &coreauth.Auth{
|
||||
ID: authID,
|
||||
Provider: "claude",
|
||||
Status: coreauth.StatusActive,
|
||||
LastRefreshedAt: lastRefreshedAt,
|
||||
NextRefreshAfter: nextRefreshAfter,
|
||||
ModelStates: map[string]*coreauth.ModelState{
|
||||
modelID: {
|
||||
Quota: coreauth.QuotaState{BackoffLevel: 7},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
service.applyCoreAuthRemoval(context.Background(), authID)
|
||||
|
||||
disabled, ok := service.coreManager.GetByID(authID)
|
||||
if !ok || disabled == nil {
|
||||
t.Fatalf("expected disabled auth after removal")
|
||||
}
|
||||
if !disabled.Disabled || disabled.Status != coreauth.StatusDisabled {
|
||||
t.Fatalf("expected disabled auth after removal, got disabled=%v status=%v", disabled.Disabled, disabled.Status)
|
||||
}
|
||||
if disabled.LastRefreshedAt.IsZero() {
|
||||
t.Fatalf("expected disabled auth to still carry prior LastRefreshedAt for regression setup")
|
||||
}
|
||||
if disabled.NextRefreshAfter.IsZero() {
|
||||
t.Fatalf("expected disabled auth to still carry prior NextRefreshAfter for regression setup")
|
||||
}
|
||||
if len(disabled.ModelStates) == 0 {
|
||||
t.Fatalf("expected disabled auth to still carry prior ModelStates for regression setup")
|
||||
}
|
||||
|
||||
service.applyCoreAuthAddOrUpdate(context.Background(), &coreauth.Auth{
|
||||
ID: authID,
|
||||
Provider: "claude",
|
||||
Status: coreauth.StatusActive,
|
||||
})
|
||||
|
||||
updated, ok := service.coreManager.GetByID(authID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected re-added auth to be present")
|
||||
}
|
||||
if updated.Disabled {
|
||||
t.Fatalf("expected re-added auth to be active")
|
||||
}
|
||||
if !updated.LastRefreshedAt.IsZero() {
|
||||
t.Fatalf("expected LastRefreshedAt to reset on delete -> re-add, got %v", updated.LastRefreshedAt)
|
||||
}
|
||||
if !updated.NextRefreshAfter.IsZero() {
|
||||
t.Fatalf("expected NextRefreshAfter to reset on delete -> re-add, got %v", updated.NextRefreshAfter)
|
||||
}
|
||||
if len(updated.ModelStates) != 0 {
|
||||
t.Fatalf("expected ModelStates to reset on delete -> re-add, got %d entries", len(updated.ModelStates))
|
||||
}
|
||||
if models := registry.GetGlobalRegistry().GetModelsForClient(authID); len(models) == 0 {
|
||||
t.Fatalf("expected re-added auth to re-register models in global registry")
|
||||
}
|
||||
}
|
||||
@@ -68,14 +68,18 @@ func Parse(raw string) (Setting, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func cloneDefaultTransport() *http.Transport {
|
||||
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
|
||||
return transport.Clone()
|
||||
}
|
||||
return &http.Transport{}
|
||||
}
|
||||
|
||||
// NewDirectTransport returns a transport that bypasses environment proxies.
|
||||
func NewDirectTransport() *http.Transport {
|
||||
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
|
||||
clone := transport.Clone()
|
||||
clone.Proxy = nil
|
||||
return clone
|
||||
}
|
||||
return &http.Transport{Proxy: nil}
|
||||
clone := cloneDefaultTransport()
|
||||
clone.Proxy = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
|
||||
@@ -102,14 +106,16 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
|
||||
if errSOCKS5 != nil {
|
||||
return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}, setting.Mode, nil
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Proxy = nil
|
||||
transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
return transport, setting.Mode, nil
|
||||
}
|
||||
return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Proxy = http.ProxyURL(setting.URL)
|
||||
return transport, setting.Mode, nil
|
||||
default:
|
||||
return nil, setting.Mode, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,16 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustDefaultTransport(t *testing.T) *http.Transport {
|
||||
t.Helper()
|
||||
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatal("http.DefaultTransport is not an *http.Transport")
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -86,4 +96,44 @@ func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
|
||||
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
|
||||
t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
|
||||
}
|
||||
|
||||
defaultTransport := mustDefaultTransport(t)
|
||||
if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
|
||||
t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
|
||||
}
|
||||
if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
|
||||
t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
|
||||
}
|
||||
if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
|
||||
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transport, mode, errBuild := BuildHTTPTransport("socks5://proxy.example.com:1080")
|
||||
if errBuild != nil {
|
||||
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
|
||||
}
|
||||
if mode != ModeProxy {
|
||||
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
|
||||
}
|
||||
if transport == nil {
|
||||
t.Fatal("expected transport, got nil")
|
||||
}
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("expected SOCKS5 transport to bypass http proxy function")
|
||||
}
|
||||
|
||||
defaultTransport := mustDefaultTransport(t)
|
||||
if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
|
||||
t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
|
||||
}
|
||||
if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
|
||||
t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
|
||||
}
|
||||
if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
|
||||
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
309
test_cursor.sh
Executable file
309
test_cursor.sh
Executable file
@@ -0,0 +1,309 @@
|
||||
#!/bin/bash
|
||||
# Test script for Cursor proxy integration
|
||||
# Usage:
|
||||
# ./test_cursor.sh login - Login to Cursor (opens browser)
|
||||
# ./test_cursor.sh start - Build and start the server
|
||||
# ./test_cursor.sh test - Run API tests against running server
|
||||
# ./test_cursor.sh all - Login + Start + Test (full flow)
|
||||
|
||||
set -e
|
||||
|
||||
export PATH="/opt/homebrew/bin:$PATH"
|
||||
export GOROOT="/opt/homebrew/Cellar/go/1.26.1/libexec"
|
||||
|
||||
PROJECT_DIR="/Volumes/Personal/cursor-cli-proxy/CLIProxyAPIPlus"
|
||||
BINARY="$PROJECT_DIR/cliproxy-test"
|
||||
API_KEY="quotio-local-D6ABC285-3085-44B4-B872-BD269888811F"
|
||||
BASE_URL="http://127.0.0.1:8317"
|
||||
CONFIG="$PROJECT_DIR/config-cursor-test.yaml"
|
||||
PID_FILE="/tmp/cliproxy-test.pid"
|
||||
|
||||
# Colors
|
||||
GREEN='\033[0;32m'
|
||||
RED='\033[0;31m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
||||
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||
error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||
|
||||
# --- Build ---
|
||||
build() {
|
||||
info "Building CLIProxyAPIPlus..."
|
||||
cd "$PROJECT_DIR"
|
||||
go build -o "$BINARY" ./cmd/server/
|
||||
info "Build successful: $BINARY"
|
||||
}
|
||||
|
||||
# --- Create test config ---
|
||||
create_config() {
|
||||
cat > "$CONFIG" << 'EOF'
|
||||
host: '127.0.0.1'
|
||||
port: 8317
|
||||
auth-dir: '~/.cli-proxy-api'
|
||||
api-keys:
|
||||
- 'quotio-local-D6ABC285-3085-44B4-B872-BD269888811F'
|
||||
debug: true
|
||||
EOF
|
||||
info "Test config created: $CONFIG"
|
||||
}
|
||||
|
||||
# --- Login ---
|
||||
do_login() {
|
||||
build
|
||||
create_config
|
||||
info "Starting Cursor login (will open browser)..."
|
||||
"$BINARY" --config "$CONFIG" --cursor-login
|
||||
}
|
||||
|
||||
# --- Start server ---
|
||||
start_server() {
|
||||
# Kill any existing instance
|
||||
stop_server 2>/dev/null || true
|
||||
|
||||
build
|
||||
create_config
|
||||
|
||||
info "Starting server on port 8317..."
|
||||
"$BINARY" --config "$CONFIG" &
|
||||
SERVER_PID=$!
|
||||
echo "$SERVER_PID" > "$PID_FILE"
|
||||
info "Server started (PID: $SERVER_PID)"
|
||||
|
||||
# Wait for server to be ready
|
||||
info "Waiting for server to be ready..."
|
||||
for i in $(seq 1 15); do
|
||||
if curl -s "$BASE_URL/v1/models" -H "Authorization: Bearer $API_KEY" > /dev/null 2>&1; then
|
||||
info "Server is ready!"
|
||||
return 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
error "Server failed to start within 15 seconds"
|
||||
return 1
|
||||
}
|
||||
|
||||
# --- Stop server ---
|
||||
stop_server() {
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
PID=$(cat "$PID_FILE")
|
||||
if kill -0 "$PID" 2>/dev/null; then
|
||||
info "Stopping server (PID: $PID)..."
|
||||
kill "$PID"
|
||||
rm -f "$PID_FILE"
|
||||
fi
|
||||
fi
|
||||
# Also kill any stale process on port 8317
|
||||
lsof -ti:8317 2>/dev/null | xargs kill 2>/dev/null || true
|
||||
}
|
||||
|
||||
# --- Test: List models ---
|
||||
test_models() {
|
||||
info "Testing GET /v1/models (looking for cursor models)..."
|
||||
RESPONSE=$(curl -s "$BASE_URL/v1/models" \
|
||||
-H "Authorization: Bearer $API_KEY")
|
||||
|
||||
CURSOR_MODELS=$(echo "$RESPONSE" | python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
data = json.load(sys.stdin)
|
||||
models = [m['id'] for m in data.get('data', []) if m.get('owned_by') == 'cursor' or m.get('type') == 'cursor']
|
||||
if models:
|
||||
print('\n'.join(models))
|
||||
else:
|
||||
print('NONE')
|
||||
except:
|
||||
print('ERROR')
|
||||
" 2>/dev/null || echo "PARSE_ERROR")
|
||||
|
||||
if [ "$CURSOR_MODELS" = "NONE" ] || [ "$CURSOR_MODELS" = "ERROR" ] || [ "$CURSOR_MODELS" = "PARSE_ERROR" ]; then
|
||||
warn "No cursor models found. Have you run '--cursor-login' first?"
|
||||
echo " Response preview: $(echo "$RESPONSE" | head -c 200)"
|
||||
return 1
|
||||
else
|
||||
info "Found cursor models:"
|
||||
echo "$CURSOR_MODELS" | while read -r model; do
|
||||
echo " - $model"
|
||||
done
|
||||
return 0
|
||||
fi
|
||||
}
|
||||
|
||||
# --- Test: Chat completion (streaming) ---
|
||||
test_chat_stream() {
|
||||
local model="${1:-cursor-small}"
|
||||
info "Testing POST /v1/chat/completions (stream, model=$model)..."
|
||||
|
||||
RESPONSE=$(curl -s --max-time 30 "$BASE_URL/v1/chat/completions" \
|
||||
-H "Authorization: Bearer $API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"$model\",
|
||||
\"messages\": [{\"role\": \"user\", \"content\": \"Say hello in exactly 3 words.\"}],
|
||||
\"stream\": true
|
||||
}" 2>&1)
|
||||
|
||||
# Check if we got SSE data
|
||||
if echo "$RESPONSE" | grep -q "data:"; then
|
||||
# Extract content from SSE chunks
|
||||
CONTENT=$(echo "$RESPONSE" | grep "^data: " | grep -v "\[DONE\]" | while read -r line; do
|
||||
echo "${line#data: }" | python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
chunk = json.load(sys.stdin)
|
||||
delta = chunk.get('choices', [{}])[0].get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
if content:
|
||||
sys.stdout.write(content)
|
||||
except:
|
||||
pass
|
||||
" 2>/dev/null
|
||||
done)
|
||||
|
||||
if [ -n "$CONTENT" ]; then
|
||||
info "Stream response received:"
|
||||
echo " Content: $CONTENT"
|
||||
return 0
|
||||
else
|
||||
warn "Got SSE chunks but no content extracted"
|
||||
echo " Raw (first 500 chars): $(echo "$RESPONSE" | head -c 500)"
|
||||
return 1
|
||||
fi
|
||||
else
|
||||
error "No SSE data received"
|
||||
echo " Response: $(echo "$RESPONSE" | head -c 300)"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# --- Test: Chat completion (non-streaming) ---
|
||||
test_chat_nonstream() {
|
||||
local model="${1:-cursor-small}"
|
||||
info "Testing POST /v1/chat/completions (non-stream, model=$model)..."
|
||||
|
||||
RESPONSE=$(curl -s --max-time 30 "$BASE_URL/v1/chat/completions" \
|
||||
-H "Authorization: Bearer $API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"$model\",
|
||||
\"messages\": [{\"role\": \"user\", \"content\": \"What is 2+2? Answer with just the number.\"}],
|
||||
\"stream\": false
|
||||
}" 2>&1)
|
||||
|
||||
CONTENT=$(echo "$RESPONSE" | python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
data = json.load(sys.stdin)
|
||||
content = data['choices'][0]['message']['content']
|
||||
print(content)
|
||||
except Exception as e:
|
||||
print(f'ERROR: {e}')
|
||||
" 2>/dev/null || echo "PARSE_ERROR")
|
||||
|
||||
if echo "$CONTENT" | grep -q "ERROR\|PARSE_ERROR"; then
|
||||
error "Non-streaming request failed"
|
||||
echo " Response: $(echo "$RESPONSE" | head -c 300)"
|
||||
return 1
|
||||
else
|
||||
info "Non-stream response received:"
|
||||
echo " Content: $CONTENT"
|
||||
return 0
|
||||
fi
|
||||
}
|
||||
|
||||
# --- Run all tests ---
|
||||
run_tests() {
|
||||
local passed=0
|
||||
local failed=0
|
||||
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo " Cursor Proxy Integration Tests"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
# Test 1: Models
|
||||
if test_models; then
|
||||
((passed++))
|
||||
else
|
||||
((failed++))
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Test 2: Streaming chat
|
||||
if test_chat_stream "cursor-small"; then
|
||||
((passed++))
|
||||
else
|
||||
((failed++))
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Test 3: Non-streaming chat
|
||||
if test_chat_nonstream "cursor-small"; then
|
||||
((passed++))
|
||||
else
|
||||
((failed++))
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "========================================="
|
||||
echo " Results: ${passed} passed, ${failed} failed"
|
||||
echo "========================================="
|
||||
|
||||
[ "$failed" -eq 0 ]
|
||||
}
|
||||
|
||||
# --- Cleanup ---
|
||||
cleanup() {
|
||||
stop_server
|
||||
rm -f "$BINARY" "$CONFIG"
|
||||
info "Cleaned up."
|
||||
}
|
||||
|
||||
# --- Main ---
|
||||
case "${1:-help}" in
|
||||
login)
|
||||
do_login
|
||||
;;
|
||||
start)
|
||||
start_server
|
||||
info "Server running. Use './test_cursor.sh test' to run tests."
|
||||
info "Use './test_cursor.sh stop' to stop."
|
||||
;;
|
||||
stop)
|
||||
stop_server
|
||||
;;
|
||||
test)
|
||||
run_tests
|
||||
;;
|
||||
all)
|
||||
info "=== Full flow: login -> start -> test ==="
|
||||
echo ""
|
||||
info "Step 1: Login to Cursor"
|
||||
do_login
|
||||
echo ""
|
||||
info "Step 2: Start server"
|
||||
start_server
|
||||
echo ""
|
||||
info "Step 3: Run tests"
|
||||
sleep 2
|
||||
run_tests
|
||||
echo ""
|
||||
info "Step 4: Cleanup"
|
||||
stop_server
|
||||
;;
|
||||
clean)
|
||||
cleanup
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {login|start|stop|test|all|clean}"
|
||||
echo ""
|
||||
echo " login - Authenticate with Cursor (opens browser)"
|
||||
echo " start - Build and start the proxy server"
|
||||
echo " stop - Stop the running server"
|
||||
echo " test - Run API tests against running server"
|
||||
echo " all - Full flow: login + start + test"
|
||||
echo " clean - Stop server and remove artifacts"
|
||||
;;
|
||||
esac
|
||||
Reference in New Issue
Block a user