feat: stash code

This commit is contained in:
黄姜恒
2026-03-25 10:14:14 +08:00
parent 7fa527193c
commit 19c52bcb60
20 changed files with 5006 additions and 0 deletions

20
cmd/mcpdebug/main.go Normal file
View File

@@ -0,0 +1,20 @@
package main
import (
"encoding/hex"
"fmt"
"os"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
// Encode MCP result with empty execId
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
// Write to file for analysis
os.WriteFile("mcp_result.bin", resultBytes)
fmt.Println("Wrote mcp_result.bin")
}

32
cmd/protocheck/main.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"fmt"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
ecm := cursorproto.NewMsg("ExecClientMessage")
// Try different field names
names := []string{
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
"shell_result", "shellResult",
}
for _, name := range names {
fd := ecm.Descriptor().Fields().ByName(name)
if fd != nil {
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
} else {
fmt.Printf("Field %q NOT FOUND\n", name)
}
}
// List all fields
fmt.Println("\nAll fields in ExecClientMessage:")
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
f := ecm.Descriptor().Fields().Get(i)
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
}
}

View File

@@ -85,6 +85,7 @@ func main() {
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var cursorLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
var kiroAWSLogin bool
@@ -123,6 +124,7 @@ func main() {
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
@@ -544,6 +546,8 @@ func main() {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if cursorLogin {
cmd.DoCursorLogin(cfg, options)
} else if kiroLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito

View File

@@ -0,0 +1,218 @@
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
package cursor
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
)
const (
CursorLoginURL = "https://cursor.com/loginDeepControl"
CursorPollURL = "https://api2.cursor.sh/auth/poll"
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
pollMaxAttempts = 150
pollBaseDelay = 1 * time.Second
pollMaxDelay = 10 * time.Second
pollBackoffMultiply = 1.2
maxConsecutiveErrors = 10
)
// AuthParams holds the PKCE parameters for Cursor login.
type AuthParams struct {
Verifier string
Challenge string
UUID string
LoginURL string
}
// TokenPair holds the access and refresh tokens from Cursor.
type TokenPair struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
// GeneratePKCE creates a PKCE verifier and challenge pair.
func GeneratePKCE() (verifier, challenge string, err error) {
verifierBytes := make([]byte, 96)
if _, err = rand.Read(verifierBytes); err != nil {
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
}
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
// GenerateAuthParams creates the full set of auth params for Cursor login.
func GenerateAuthParams() (*AuthParams, error) {
verifier, challenge, err := GeneratePKCE()
if err != nil {
return nil, err
}
uuidBytes := make([]byte, 16)
if _, err = rand.Read(uuidBytes); err != nil {
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
}
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
CursorLoginURL, challenge, uuid)
return &AuthParams{
Verifier: verifier,
Challenge: challenge,
UUID: uuid,
LoginURL: loginURL,
}, nil
}
// PollForAuth polls the Cursor auth endpoint until the user completes login.
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
delay := pollBaseDelay
consecutiveErrors := 0
client := &http.Client{Timeout: 10 * time.Second}
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(delay):
}
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
consecutiveErrors++
if consecutiveErrors >= maxConsecutiveErrors {
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
}
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Still waiting for user to authorize
consecutiveErrors = 0
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
}
return &tokens, nil
}
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
}
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
}
// RefreshToken refreshes a Cursor access token using the refresh token.
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
strings.NewReader("{}"))
if err != nil {
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+refreshToken)
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
}
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
}
// Keep original refresh token if not returned
if tokens.RefreshToken == "" {
tokens.RefreshToken = refreshToken
}
return &tokens, nil
}
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
// Falls back to 1 hour from now if the token can't be parsed.
func GetTokenExpiry(token string) time.Time {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return time.Now().Add(1 * time.Hour)
}
// Decode the payload (middle part)
payload := parts[1]
// Add padding if needed
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
// Replace URL-safe characters
payload = strings.ReplaceAll(payload, "-", "+")
payload = strings.ReplaceAll(payload, "_", "/")
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return time.Now().Add(1 * time.Hour)
}
var claims struct {
Exp float64 `json:"exp"`
}
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
return time.Now().Add(1 * time.Hour)
}
sec, frac := math.Modf(claims.Exp)
expiry := time.Unix(int64(sec), int64(frac*1e9))
// Subtract 5-minute safety margin
return expiry.Add(-5 * time.Minute)
}
func minDuration(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,71 @@
package proto
import (
"encoding/binary"
"encoding/json"
"fmt"
)
const (
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
ConnectEndStreamFlag byte = 0x02
// ConnectCompressionFlag indicates the payload is compressed (not supported).
ConnectCompressionFlag byte = 0x01
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
ConnectFrameHeaderSize = 5
)
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
frame := make([]byte, ConnectFrameHeaderSize+len(data))
frame[0] = flags
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
copy(frame[5:], data)
return frame
}
// ParseConnectFrame extracts one frame from a buffer.
// Returns (flags, payload, bytesConsumed, ok).
// ok is false when the buffer is too short for a complete frame.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
if len(buf) < ConnectFrameHeaderSize {
return 0, nil, 0, false
}
flags = buf[0]
length := binary.BigEndian.Uint32(buf[1:5])
total := ConnectFrameHeaderSize + int(length)
if len(buf) < total {
return 0, nil, 0, false
}
return flags, buf[5:total], total, true
}
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
// Returns nil if there is no error in the trailer.
func ParseConnectEndStream(data []byte) error {
if len(data) == 0 {
return nil
}
var trailer struct {
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
if err := json.Unmarshal(data, &trailer); err != nil {
return fmt.Errorf("failed to parse Connect end stream: %w", err)
}
if trailer.Error != nil {
code := trailer.Error.Code
if code == "" {
code = "unknown"
}
msg := trailer.Error.Message
if msg == "" {
msg = "Unknown error"
}
return fmt.Errorf("Connect error %s: %s", code, msg)
}
return nil
}

View File

@@ -0,0 +1,507 @@
package proto
import (
"encoding/hex"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
)
// ServerMessageType identifies the kind of decoded server message.
type ServerMessageType int
const (
ServerMsgUnknown ServerMessageType = iota
ServerMsgTextDelta // Text content delta
ServerMsgThinkingDelta // Thinking/reasoning delta
ServerMsgThinkingCompleted // Thinking completed
ServerMsgKvGetBlob // Server wants a blob
ServerMsgKvSetBlob // Server wants to store a blob
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
ServerMsgExecMcpArgs // Server wants MCP tool execution
ServerMsgExecShellArgs // Rejected: shell command
ServerMsgExecReadArgs // Rejected: file read
ServerMsgExecWriteArgs // Rejected: file write
ServerMsgExecDeleteArgs // Rejected: file delete
ServerMsgExecLsArgs // Rejected: directory listing
ServerMsgExecGrepArgs // Rejected: grep search
ServerMsgExecFetchArgs // Rejected: HTTP fetch
ServerMsgExecDiagnostics // Respond with empty diagnostics
ServerMsgExecShellStream // Rejected: shell stream
ServerMsgExecBgShellSpawn // Rejected: background shell
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
ServerMsgExecOther // Other exec types (respond with empty)
)
// DecodedServerMessage holds parsed data from an AgentServerMessage.
type DecodedServerMessage struct {
Type ServerMessageType
// For text/thinking deltas
Text string
// For KV messages
KvId uint32
BlobId []byte // hex-encoded blob ID
BlobData []byte // for setBlobArgs
// For exec messages
ExecMsgId uint32
ExecId string
// For MCP args
McpToolName string
McpToolCallId string
McpArgs map[string][]byte // arg name -> protobuf-encoded value
// For rejection context
Path string
Command string
WorkingDirectory string
Url string
// For other exec - the raw field number for building a response
ExecFieldNumber int
}
// DecodeAgentServerMessage parses an AgentServerMessage and returns
// a structured representation of the first meaningful message found.
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return msg, fmt.Errorf("invalid tag")
}
data = data[n:]
switch typ {
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return msg, fmt.Errorf("invalid bytes field %d", num)
}
data = data[n:]
// Debug: log top-level ASM fields
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
switch num {
case ASM_InteractionUpdate:
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
decodeInteractionUpdate(val, msg)
case ASM_ExecServerMessage:
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
decodeExecServerMessage(val, msg)
case ASM_KvServerMessage:
decodeKvServerMessage(val, msg)
case ASM_ConversationCheckpoint:
// Ignore checkpoint updates
log.Debugf("DecodeAgentServerMessage: ignoring ConversationCheckpoint")
}
case protowire.VarintType:
_, n := protowire.ConsumeVarint(data)
if n < 0 {
return msg, fmt.Errorf("invalid varint field %d", num)
}
data = data[n:]
default:
// Skip unknown wire types
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return msg, fmt.Errorf("invalid field %d", num)
}
data = data[n:]
}
}
return msg, nil
}
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case IU_TextDelta:
msg.Type = ServerMsgTextDelta
msg.Text = decodeStringField(val, TDU_Text)
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
case IU_ThinkingDelta:
msg.Type = ServerMsgThinkingDelta
msg.Text = decodeStringField(val, TKD_Text)
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
case IU_ThinkingCompleted:
msg.Type = ServerMsgThinkingCompleted
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
case 2:
// tool_call_started - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
case 3:
// tool_call_completed - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
default:
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == KSM_Id {
msg.KvId = uint32(val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case KSM_GetBlobArgs:
msg.Type = ServerMsgKvGetBlob
msg.BlobId = decodeBytesField(val, GBA_BlobId)
case KSM_SetBlobArgs:
msg.Type = ServerMsgKvSetBlob
decodeSetBlobArgs(val, msg)
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SBA_BlobId:
msg.BlobId = val
case SBA_BlobData:
msg.BlobData = val
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == ESM_Id {
msg.ExecMsgId = uint32(val)
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
// Debug: log all fields found in ExecServerMessage
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case ESM_ExecId:
msg.ExecId = string(val)
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
case ESM_RequestContextArgs:
msg.Type = ServerMsgExecRequestCtx
case ESM_McpArgs:
msg.Type = ServerMsgExecMcpArgs
decodeMcpArgs(val, msg)
case ESM_ShellArgs:
msg.Type = ServerMsgExecShellArgs
decodeShellArgs(val, msg)
case ESM_ShellStreamArgs:
msg.Type = ServerMsgExecShellStream
decodeShellArgs(val, msg)
case ESM_ReadArgs:
msg.Type = ServerMsgExecReadArgs
msg.Path = decodeStringField(val, RA_Path)
case ESM_WriteArgs:
msg.Type = ServerMsgExecWriteArgs
msg.Path = decodeStringField(val, WA_Path)
case ESM_DeleteArgs:
msg.Type = ServerMsgExecDeleteArgs
msg.Path = decodeStringField(val, DA_Path)
case ESM_LsArgs:
msg.Type = ServerMsgExecLsArgs
msg.Path = decodeStringField(val, LA_Path)
case ESM_GrepArgs:
msg.Type = ServerMsgExecGrepArgs
case ESM_FetchArgs:
msg.Type = ServerMsgExecFetchArgs
msg.Url = decodeStringField(val, FA_Url)
case ESM_DiagnosticsArgs:
msg.Type = ServerMsgExecDiagnostics
case ESM_BackgroundShellSpawn:
msg.Type = ServerMsgExecBgShellSpawn
decodeShellArgs(val, msg) // same structure
case ESM_WriteShellStdinArgs:
msg.Type = ServerMsgExecWriteShellStdin
default:
// Unknown exec types - only set if we haven't identified the type yet
// (other fields like span_context (19) come after the exec type field)
if msg.Type == ServerMsgUnknown {
msg.Type = ServerMsgExecOther
msg.ExecFieldNumber = int(num)
}
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
msg.McpArgs = make(map[string][]byte)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case MCA_Name:
msg.McpToolName = string(val)
case MCA_Args:
// Map entries are encoded as submessages with key=1, value=2
decodeMapEntry(val, msg.McpArgs)
case MCA_ToolCallId:
msg.McpToolCallId = string(val)
case MCA_ToolName:
// ToolName takes precedence if present
if msg.McpToolName == "" || string(val) != "" {
msg.McpToolName = string(val)
}
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMapEntry(data []byte, m map[string][]byte) {
var key string
var value []byte
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
if num == 1 {
key = string(val)
} else if num == 2 {
value = append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
if key != "" {
m[key] = value
}
}
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SHA_Command:
msg.Command = string(val)
case SHA_WorkingDirectory:
msg.WorkingDirectory = string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
// --- Helper decoders ---
// decodeStringField extracts a string from the first matching field in a submessage.
func decodeStringField(data []byte, targetField protowire.Number) string {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return ""
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return ""
}
data = data[n:]
if num == targetField {
return string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return ""
}
data = data[n:]
}
}
return ""
}
// decodeBytesField extracts bytes from the first matching field in a submessage.
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return nil
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return nil
}
data = data[n:]
if num == targetField {
return append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return nil
}
data = data[n:]
}
}
return nil
}
// BlobIdHex returns the hex string of a blob ID for use as a map key.
func BlobIdHex(blobId []byte) string {
return hex.EncodeToString(blobId)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,491 @@
// Package proto provides protobuf encoding for Cursor's gRPC API,
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
package proto
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/known/structpb"
)
// --- Public types ---
// RunRequestParams holds all data needed to build an AgentRunRequest.
type RunRequestParams struct {
ModelId string
SystemPrompt string
UserText string
MessageId string
ConversationId string
Images []ImageData
Turns []TurnData
McpTools []McpToolDef
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
}
type ImageData struct {
MimeType string
Data []byte
}
type TurnData struct {
UserText string
AssistantText string
}
type McpToolDef struct {
Name string
Description string
InputSchema json.RawMessage
}
// --- Helper: create a dynamic message and set fields ---
func newMsg(name string) *dynamicpb.Message {
return dynamicpb.NewMessage(Msg(name))
}
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
}
func setStr(msg *dynamicpb.Message, name, val string) {
if val != "" {
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
}
}
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
if len(val) > 0 {
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
}
}
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
}
func setBool(msg *dynamicpb.Message, name string, val bool) {
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
}
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
}
func marshal(msg *dynamicpb.Message) []byte {
b, err := proto.Marshal(msg)
if err != nil {
panic("cursor proto marshal: " + err.Error())
}
return b
}
// --- Encode functions mirroring cursor-fetch.ts ---
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
func EncodeHeartbeat() []byte {
hb := newMsg("ClientHeartbeat")
acm := newMsg("AgentClientMessage")
setMsg(acm, "client_heartbeat", hb)
return marshal(acm)
}
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
// Mirrors buildCursorRequest() in cursor-fetch.ts.
func EncodeRunRequest(p *RunRequestParams) []byte {
if p.BlobStore == nil {
p.BlobStore = make(map[string][]byte)
}
// --- Conversation turns ---
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
var turnBytes [][]byte
for _, turn := range p.Turns {
// UserMessage for this turn
um := newMsg("UserMessage")
setStr(um, "text", turn.UserText)
setStr(um, "message_id", generateId())
umBytes := marshal(um)
// Steps (assistant response)
var stepBytes [][]byte
if turn.AssistantText != "" {
am := newMsg("AssistantMessage")
setStr(am, "text", turn.AssistantText)
step := newMsg("ConversationStep")
setMsg(step, "assistant_message", am)
stepBytes = append(stepBytes, marshal(step))
}
// AgentConversationTurnStructure (fields are bytes, not submessages)
agentTurn := newMsg("AgentConversationTurnStructure")
setBytes(agentTurn, "user_message", umBytes)
for _, sb := range stepBytes {
stepsField := field(agentTurn, "steps")
list := agentTurn.Mutable(stepsField).List()
list.Append(protoreflect.ValueOfBytes(sb))
}
// ConversationTurnStructure (oneof turn → agentConversationTurn)
cts := newMsg("ConversationTurnStructure")
setMsg(cts, "agent_conversation_turn", agentTurn)
turnBytes = append(turnBytes, marshal(cts))
}
// --- System prompt blob ---
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
blobId := sha256Sum(systemJSON)
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
// --- ConversationStateStructure ---
css := newMsg("ConversationStateStructure")
// rootPromptMessagesJson: repeated bytes
rootField := field(css, "root_prompt_messages_json")
rootList := css.Mutable(rootField).List()
rootList.Append(protoreflect.ValueOfBytes(blobId))
// turns: repeated bytes
turnsField := field(css, "turns")
turnsList := css.Mutable(turnsField).List()
for _, tb := range turnBytes {
turnsList.Append(protoreflect.ValueOfBytes(tb))
}
// --- UserMessage (current) ---
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
// Images via SelectedContext
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// --- UserMessageAction ---
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
// --- ConversationAction ---
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
// --- ModelDetails ---
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// --- AgentRunRequest ---
arr := newMsg("AgentRunRequest")
setMsg(arr, "conversation_state", css)
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
// --- AgentClientMessage ---
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// --- KV response encoders ---
// Mirrors handleKvMessage() in cursor-fetch.ts
// EncodeKvGetBlobResult responds to a getBlobArgs request.
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
result := newMsg("GetBlobResult")
if blobData != nil {
setBytes(result, "blob_data", blobData)
}
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "get_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// EncodeKvSetBlobResult responds to a setBlobArgs request.
func EncodeKvSetBlobResult(kvId uint32) []byte {
result := newMsg("SetBlobResult")
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "set_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// --- Exec response encoders ---
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(tools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range tools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// RequestContextSuccess
rcs := newMsg("RequestContextSuccess")
setMsg(rcs, "request_context", rc)
// RequestContextResult (oneof success)
rcr := newMsg("RequestContextResult")
setMsg(rcr, "success", rcs)
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
}
// EncodeExecMcpResult responds with MCP tool result.
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
textContent := newMsg("McpTextContent")
setStr(textContent, "text", content)
contentItem := newMsg("McpToolResultContentItem")
setMsg(contentItem, "text", textContent)
success := newMsg("McpSuccess")
contentField := field(success, "content")
contentList := success.Mutable(contentField).List()
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
setBool(success, "is_error", isError)
result := newMsg("McpResult")
setMsg(result, "success", success)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// EncodeExecMcpError responds with MCP error.
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
mcpErr := newMsg("McpError")
setStr(mcpErr, "error", errMsg)
result := newMsg("McpResult")
setMsg(result, "error", mcpErr)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// --- Rejection encoders (mirror handleExecMessage rejections) ---
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("ReadRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("ReadResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
}
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("ShellResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
}
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("WriteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("WriteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
}
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("DeleteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("DeleteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
}
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("LsRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("LsResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
}
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
grepErr := newMsg("GrepError")
setStr(grepErr, "error", errMsg)
result := newMsg("GrepResult")
setMsg(result, "error", grepErr)
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
}
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
fetchErr := newMsg("FetchError")
setStr(fetchErr, "url", url)
setStr(fetchErr, "error", errMsg)
result := newMsg("FetchResult")
setMsg(result, "error", fetchErr)
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
}
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
result := newMsg("DiagnosticsResult")
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
}
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("BackgroundShellSpawnResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
}
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
wsErr := newMsg("WriteShellStdinError")
setStr(wsErr, "error", errMsg)
result := newMsg("WriteShellStdinResult")
setMsg(result, "error", wsErr)
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
}
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
// Mirrors sendExec() in cursor-fetch.ts.
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
ecm := newMsg("ExecClientMessage")
setUint32(ecm, "id", id)
// Force set exec_id even if empty - Cursor requires this field to be set
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
// Debug: check if field exists
fd := field(ecm, resultFieldName)
if fd == nil {
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
}
// Debug: log the actual field being set
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
acm := newMsg("AgentClientMessage")
setMsg(acm, "exec_client_message", ecm)
return marshal(acm)
}
func listFields(msg *dynamicpb.Message) []string {
var names []string
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
}
return names
}
// --- Utilities ---
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
if len(jsonData) == 0 {
return nil
}
var v interface{}
if err := json.Unmarshal(jsonData, &v); err != nil {
return jsonData // fallback to raw JSON if parsing fails
}
pbVal, err := structpb.NewValue(v)
if err != nil {
return jsonData // fallback
}
b, err := proto.Marshal(pbVal)
if err != nil {
return jsonData // fallback
}
return b
}
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
val := &structpb.Value{}
if err := proto.Unmarshal(data, val); err != nil {
return nil, err
}
return val.AsInterface(), nil
}
func sha256Sum(data []byte) []byte {
h := sha256.Sum256(data)
return h[:]
}
var idCounter uint64
func generateId() string {
idCounter++
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
return hex.EncodeToString(h[:16])
}

View File

@@ -0,0 +1,332 @@
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
package proto
// AgentClientMessage (msg 118) oneof "message"
const (
ACM_RunRequest = 1 // AgentRunRequest
ACM_ExecClientMessage = 2 // ExecClientMessage
ACM_KvClientMessage = 3 // KvClientMessage
ACM_ConversationAction = 4 // ConversationAction
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
ACM_InteractionResponse = 6 // InteractionResponse
ACM_ClientHeartbeat = 7 // ClientHeartbeat
)
// AgentServerMessage (msg 119) oneof "message"
const (
ASM_InteractionUpdate = 1 // InteractionUpdate
ASM_ExecServerMessage = 2 // ExecServerMessage
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
ASM_KvServerMessage = 4 // KvServerMessage
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
ASM_InteractionQuery = 7 // InteractionQuery
)
// AgentRunRequest (msg 91)
const (
ARR_ConversationState = 1 // ConversationStateStructure
ARR_Action = 2 // ConversationAction
ARR_ModelDetails = 3 // ModelDetails
ARR_McpTools = 4 // McpTools
ARR_ConversationId = 5 // string (optional)
)
// ConversationStateStructure (msg 83)
const (
CSS_RootPromptMessagesJson = 1 // repeated bytes
CSS_TurnsOld = 2 // repeated bytes (deprecated)
CSS_Todos = 3 // repeated bytes
CSS_PendingToolCalls = 4 // repeated string
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
CSS_PreviousWorkspaceUris = 9 // repeated string
CSS_SelfSummaryCount = 17 // uint32
CSS_ReadPaths = 18 // repeated string
)
// ConversationAction (msg 54) oneof "action"
const (
CA_UserMessageAction = 1 // UserMessageAction
)
// UserMessageAction (msg 55)
const (
UMA_UserMessage = 1 // UserMessage
)
// UserMessage (msg 63)
const (
UM_Text = 1 // string
UM_MessageId = 2 // string
UM_SelectedContext = 3 // SelectedContext (optional)
)
// SelectedContext
const (
SC_SelectedImages = 1 // repeated SelectedImage
)
// SelectedImage
const (
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
SI_Uuid = 2 // string
SI_Path = 3 // string
SI_MimeType = 7 // string
SI_Data = 8 // bytes (oneof dataOrBlobId)
)
// ModelDetails (msg 88)
const (
MD_ModelId = 1 // string
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
MD_DisplayModelId = 3 // string
MD_DisplayName = 4 // string
)
// McpTools (msg 307)
const (
MT_McpTools = 1 // repeated McpToolDefinition
)
// McpToolDefinition (msg 306)
const (
MTD_Name = 1 // string
MTD_Description = 2 // string
MTD_InputSchema = 3 // bytes
MTD_ProviderIdentifier = 4 // string
MTD_ToolName = 5 // string
)
// ConversationTurnStructure (msg 70) oneof "turn"
const (
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
)
// AgentConversationTurnStructure (msg 72)
const (
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
)
// ConversationStep (msg 53) oneof "message"
const (
CS_AssistantMessage = 1 // AssistantMessage
)
// AssistantMessage
const (
AM_Text = 1 // string
)
// --- Server-side message fields ---
// InteractionUpdate oneof "message"
const (
IU_TextDelta = 1 // TextDeltaUpdate
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
)
// TextDeltaUpdate (msg 92)
const (
TDU_Text = 1 // string
)
// ThinkingDeltaUpdate (msg 97)
const (
TKD_Text = 1 // string
)
// KvServerMessage (msg 271)
const (
KSM_Id = 1 // uint32
KSM_GetBlobArgs = 2 // GetBlobArgs
KSM_SetBlobArgs = 3 // SetBlobArgs
)
// GetBlobArgs (msg 267)
const (
GBA_BlobId = 1 // bytes
)
// SetBlobArgs (msg 269)
const (
SBA_BlobId = 1 // bytes
SBA_BlobData = 2 // bytes
)
// KvClientMessage (msg 272)
const (
KCM_Id = 1 // uint32
KCM_GetBlobResult = 2 // GetBlobResult
KCM_SetBlobResult = 3 // SetBlobResult
)
// GetBlobResult (msg 268)
const (
GBR_BlobData = 1 // bytes (optional)
)
// ExecServerMessage
const (
ESM_Id = 1 // uint32
ESM_ExecId = 15 // string
// oneof message:
ESM_ShellArgs = 2 // ShellArgs
ESM_WriteArgs = 3 // WriteArgs
ESM_DeleteArgs = 4 // DeleteArgs
ESM_GrepArgs = 5 // GrepArgs
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
ESM_LsArgs = 8 // LsArgs
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
ESM_RequestContextArgs = 10 // RequestContextArgs
ESM_McpArgs = 11 // McpArgs
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
ESM_FetchArgs = 20 // FetchArgs
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
)
// ExecClientMessage
const (
ECM_Id = 1 // uint32
ECM_ExecId = 15 // string
// oneof message (mirrors server fields):
ECM_ShellResult = 2
ECM_WriteResult = 3
ECM_DeleteResult = 4
ECM_GrepResult = 5
ECM_ReadResult = 7
ECM_LsResult = 8
ECM_DiagnosticsResult = 9
ECM_RequestContextResult = 10
ECM_McpResult = 11
ECM_ShellStream = 14
ECM_BackgroundShellSpawnRes = 16
ECM_FetchResult = 20
ECM_WriteShellStdinResult = 23
)
// McpArgs
const (
MCA_Name = 1 // string
MCA_Args = 2 // map<string, bytes>
MCA_ToolCallId = 3 // string
MCA_ProviderIdentifier = 4 // string
MCA_ToolName = 5 // string
)
// RequestContextResult oneof "result"
const (
RCR_Success = 1 // RequestContextSuccess
RCR_Error = 2 // RequestContextError
)
// RequestContextSuccess (msg 337)
const (
RCS_RequestContext = 1 // RequestContext
)
// RequestContext
const (
RC_Rules = 2 // repeated CursorRule
RC_Tools = 7 // repeated McpToolDefinition
)
// McpResult oneof "result"
const (
MCR_Success = 1 // McpSuccess
MCR_Error = 2 // McpError
MCR_Rejected = 3 // McpRejected
)
// McpSuccess (msg 290)
const (
MCS_Content = 1 // repeated McpToolResultContentItem
MCS_IsError = 2 // bool
)
// McpToolResultContentItem oneof "content"
const (
MTRCI_Text = 1 // McpTextContent
)
// McpTextContent (msg 287)
const (
MTC_Text = 1 // string
)
// McpError (msg 291)
const (
MCE_Error = 1 // string
)
// --- Rejection messages ---
// ReadRejected: path=1, reason=2
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
// WriteRejected: path=1, reason=2
// DeleteRejected: path=1, reason=2
// LsRejected: path=1, reason=2
// GrepError: error=1
// FetchError: url=1, error=2
// WriteShellStdinError: error=1
// ReadResult oneof: success=1, error=2, rejected=3
// ShellResult oneof: success=1 (+ various), rejected=?
// The TS code uses specific result field numbers from the oneof:
const (
RR_Rejected = 3 // ReadResult.rejected
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
WR_Rejected = 5 // WriteResult.rejected
DR_Rejected = 3 // DeleteResult.rejected
LR_Rejected = 3 // LsResult.rejected
GR_Error = 2 // GrepResult.error
FR_Error = 2 // FetchResult.error
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
WSSR_Error = 2 // WriteShellStdinResult.error
)
// --- Rejection struct fields ---
const (
REJ_Path = 1
REJ_Reason = 2
SREJ_Command = 1
SREJ_WorkingDir = 2
SREJ_Reason = 3
SREJ_IsReadonly = 4
GERR_Error = 1
FERR_Url = 1
FERR_Error = 2
)
// ReadArgs
const (
RA_Path = 1 // string
)
// WriteArgs
const (
WA_Path = 1 // string
)
// DeleteArgs
const (
DA_Path = 1 // string
)
// LsArgs
const (
LA_Path = 1 // string
)
// ShellArgs
const (
SHA_Command = 1 // string
SHA_WorkingDirectory = 2 // string
)
// FetchArgs
const (
FA_Url = 1 // string
)

View File

@@ -0,0 +1,273 @@
package proto
import (
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
type H2Stream struct {
framer *http2.Framer
conn net.Conn
streamID uint32
mu sync.Mutex
id string // unique identifier for debugging
frameNum int64 // sequential frame counter for debugging
dataCh chan []byte
doneCh chan struct{}
err error
}
// ID returns the unique identifier for this stream (for logging).
func (s *H2Stream) ID() string { return s.id }
// FrameNum returns the current frame number for debugging.
func (s *H2Stream) FrameNum() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.frameNum
}
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
NextProtos: []string{"h2"},
})
if err != nil {
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
}
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
tlsConn.Close()
return nil, fmt.Errorf("h2: server did not negotiate h2")
}
framer := http2.NewFramer(tlsConn, tlsConn)
// Client connection preface
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: preface write failed: %w", err)
}
// Send initial SETTINGS (with large initial window)
if err := framer.WriteSettings(
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: settings write failed: %w", err)
}
// Connection-level window update (default is 65535, bump it up)
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: window update failed: %w", err)
}
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
for i := 0; i < 5; i++ {
f, err := framer.ReadFrame()
if err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
}
switch sf := f.(type) {
case *http2.SettingsFrame:
if !sf.IsAck() {
framer.WriteSettingsAck()
} else {
goto handshakeDone
}
case *http2.WindowUpdateFrame:
// ignore
default:
// unexpected but continue
}
}
handshakeDone:
// Build HEADERS
streamID := uint32(1)
var hdrBuf []byte
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
if p, ok := headers[":path"]; ok {
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
}
for k, v := range headers {
if len(k) > 0 && k[0] == ':' {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
if err := framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: hdrBuf,
EndStream: false,
EndHeaders: true,
}); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: headers write failed: %w", err)
}
s := &H2Stream{
framer: framer,
conn: tlsConn,
streamID: streamID,
dataCh: make(chan []byte, 256),
doneCh: make(chan struct{}),
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
frameNum: 0,
}
go s.readLoop()
return s, nil
}
// Write sends a DATA frame on the stream.
func (s *H2Stream) Write(data []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
const maxFrame = 16384
for len(data) > 0 {
chunk := data
if len(chunk) > maxFrame {
chunk = data[:maxFrame]
}
data = data[len(chunk):]
if err := s.framer.WriteData(s.streamID, false, chunk); err != nil {
return err
}
}
// Try to flush the underlying connection if it supports it
if flusher, ok := s.conn.(interface{ Flush() error }); ok {
flusher.Flush()
}
return nil
}
// Data returns the channel of received data chunks.
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
// Done returns a channel closed when the stream ends.
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
// Close tears down the connection.
func (s *H2Stream) Close() {
s.conn.Close()
}
func (s *H2Stream) readLoop() {
defer close(s.doneCh)
defer close(s.dataCh)
log.Debugf("h2stream[%s]: readLoop started for streamID=%d", s.id, s.streamID)
for {
f, err := s.framer.ReadFrame()
if err != nil {
if err != io.EOF {
s.err = err
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
} else {
log.Debugf("h2stream[%s]: readLoop EOF", s.id)
}
return
}
// Increment frame counter for debugging
s.mu.Lock()
s.frameNum++
frameNum := s.frameNum
s.mu.Unlock()
switch frame := f.(type) {
case *http2.DataFrame:
log.Debugf("h2stream[%s]: frame#%d received DATA frame streamID=%d, len=%d, endStream=%v", s.id, frameNum, frame.StreamID, len(frame.Data()), frame.StreamEnded())
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
cp := make([]byte, len(frame.Data()))
copy(cp, frame.Data())
// Log first 20 bytes for debugging
previewLen := len(cp)
if previewLen > 20 {
previewLen = 20
}
log.Debugf("h2stream[%s]: frame#%d sending to dataCh: len=%d, dataCh len=%d/%d, first bytes: %x (%q)", s.id, frameNum, len(cp), len(s.dataCh), cap(s.dataCh), cp[:previewLen], string(cp[:previewLen]))
s.dataCh <- cp
// Flow control: send WINDOW_UPDATE
s.mu.Lock()
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
s.mu.Unlock()
}
if frame.StreamEnded() {
log.Debugf("h2stream[%s]: frame#%d DATA frame has END_STREAM flag, stream ending", s.id, frameNum)
return
}
case *http2.HeadersFrame:
// Decode HPACK headers for debugging
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {
log.Debugf("h2stream[%s]: frame#%d header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
// Check for error status
if hf.Name == "grpc-status" || hf.Name == ":status" && hf.Value != "200" {
log.Warnf("h2stream[%s]: frame#%d received error status header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
}
})
decoder.Write(frame.HeaderBlockFragment())
log.Debugf("h2stream[%s]: frame#%d received HEADERS frame streamID=%d, endStream=%v", s.id, frameNum, frame.StreamID, frame.StreamEnded())
if frame.StreamEnded() {
log.Debugf("h2stream[%s]: frame#%d HEADERS frame has END_STREAM flag, stream ending", s.id, frameNum)
return
}
case *http2.RSTStreamFrame:
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
log.Debugf("h2stream[%s]: frame#%d received RST_STREAM code=%d", s.id, frameNum, frame.ErrCode)
return
case *http2.GoAwayFrame:
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
log.Debugf("h2stream[%s]: received GOAWAY code=%d", s.id, frame.ErrCode)
return
case *http2.PingFrame:
log.Debugf("h2stream[%s]: received PING frame, isAck=%v", s.id, frame.IsAck())
if !frame.IsAck() {
s.mu.Lock()
s.framer.WritePing(true, frame.Data)
s.mu.Unlock()
}
case *http2.SettingsFrame:
log.Debugf("h2stream[%s]: received SETTINGS frame, isAck=%v, numSettings=%d", s.id, frame.IsAck(), frame.NumSettings())
if !frame.IsAck() {
s.mu.Lock()
s.framer.WriteSettingsAck()
s.mu.Unlock()
}
case *http2.WindowUpdateFrame:
log.Debugf("h2stream[%s]: received WINDOW_UPDATE frame", s.id)
}
}
}
type sliceWriter struct{ buf *[]byte }
func (w *sliceWriter) Write(p []byte) (int, error) {
*w.buf = append(*w.buf, p...)
return len(p), nil
}

View File

@@ -24,6 +24,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
sdkAuth.NewCodeBuddyAuthenticator(),
sdkAuth.NewCursorAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,38 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
}
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
if err != nil {
log.Errorf("Cursor authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Cursor authentication successful!")
}

View File

@@ -231,11 +231,25 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetAntigravityModels()
case "codebuddy":
return GetCodeBuddyModels()
case "cursor":
return GetCursorModels()
default:
return nil
}
}
// GetCursorModels returns the fallback Cursor model definitions.
func GetCursorModels() []*ModelInfo {
return []*ModelInfo{
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
@@ -260,6 +274,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetKiloModels(),
GetAmazonQModels(),
GetCodeBuddyModels(),
GetCursorModels(),
}
for _, models := range allModels {
for _, m := range models {

File diff suppressed because it is too large Load Diff

View File

@@ -57,6 +57,12 @@ func GetProviderName(modelName string) []string {
return providers
}
// Fallback: if cursor provider has registered models, route unknown models to it.
// Cursor acts as a universal proxy supporting multiple model families (Claude, GPT, Gemini, etc.).
if models := registry.GetGlobalRegistry().GetAvailableModelsByProvider("cursor"); len(models) > 0 {
return []string{"cursor"}
}
return providers
}

91
sdk/auth/cursor.go Normal file
View File

@@ -0,0 +1,91 @@
package auth
import (
"context"
"fmt"
"time"
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// CursorAuthenticator implements OAuth PKCE login for Cursor.
type CursorAuthenticator struct{}
// NewCursorAuthenticator constructs a new Cursor authenticator.
func NewCursorAuthenticator() Authenticator {
return &CursorAuthenticator{}
}
// Provider returns the provider key for cursor.
func (CursorAuthenticator) Provider() string {
return "cursor"
}
// RefreshLead returns the time before expiry when a refresh should be attempted.
func (CursorAuthenticator) RefreshLead() *time.Duration {
d := 10 * time.Minute
return &d
}
// Login initiates the Cursor PKCE authentication flow.
func (a CursorAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cursor auth: configuration is required")
}
if opts == nil {
opts = &LoginOptions{}
}
// Generate PKCE auth parameters
authParams, err := cursorauth.GenerateAuthParams()
if err != nil {
return nil, fmt.Errorf("cursor: failed to generate auth params: %w", err)
}
// Display the login URL
fmt.Println("Starting Cursor authentication...")
fmt.Printf("\nPlease visit this URL to log in:\n%s\n\n", authParams.LoginURL)
// Try to open the browser automatically
if !opts.NoBrowser {
if browser.IsAvailable() {
if errOpen := browser.OpenURL(authParams.LoginURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
}
}
}
fmt.Println("Waiting for Cursor authorization...")
// Poll for the auth result
tokens, err := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
if err != nil {
return nil, fmt.Errorf("cursor: authentication failed: %w", err)
}
expiresAt := cursorauth.GetTokenExpiry(tokens.AccessToken)
fmt.Println("\nCursor authentication successful!")
metadata := map[string]any{
"type": "cursor",
"access_token": tokens.AccessToken,
"refresh_token": tokens.RefreshToken,
"expires_at": expiresAt.Format(time.RFC3339),
"timestamp": time.Now().UnixMilli(),
}
fileName := "cursor.json"
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: "cursor-user",
Metadata: metadata,
}, nil
}

View File

@@ -19,6 +19,7 @@ func init() {
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
registerRefreshLead("codebuddy", func() Authenticator { return NewCodeBuddyAuthenticator() })
registerRefreshLead("cursor", func() Authenticator { return NewCursorAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {

View File

@@ -545,6 +545,11 @@ func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
if modelKey == "" {
return true
}
// Cursor acts as a universal proxy supporting multiple model families.
// Allow any model to be routed to cursor auth.
if m.providerKey == "cursor" {
return true
}
if len(m.supportedModelSet) == 0 {
return false
}

View File

@@ -441,6 +441,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
case "kilo":
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
case "cursor":
s.coreManager.RegisterExecutor(executor.NewCursorExecutor(s.cfg))
case "github-copilot":
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
case "codebuddy":
@@ -942,6 +944,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kimi":
models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded)
case "cursor":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
models = executor.FetchCursorModels(ctx, a, s.cfg)
models = applyExcludedModels(models, excluded)
case "github-copilot":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()

309
test_cursor.sh Executable file
View File

@@ -0,0 +1,309 @@
#!/bin/bash
# Test script for Cursor proxy integration
# Usage:
# ./test_cursor.sh login - Login to Cursor (opens browser)
# ./test_cursor.sh start - Build and start the server
# ./test_cursor.sh test - Run API tests against running server
# ./test_cursor.sh all - Login + Start + Test (full flow)
set -e
export PATH="/opt/homebrew/bin:$PATH"
export GOROOT="/opt/homebrew/Cellar/go/1.26.1/libexec"
PROJECT_DIR="/Volumes/Personal/cursor-cli-proxy/CLIProxyAPIPlus"
BINARY="$PROJECT_DIR/cliproxy-test"
API_KEY="quotio-local-D6ABC285-3085-44B4-B872-BD269888811F"
BASE_URL="http://127.0.0.1:8317"
CONFIG="$PROJECT_DIR/config-cursor-test.yaml"
PID_FILE="/tmp/cliproxy-test.pid"
# Colors
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m'
info() { echo -e "${GREEN}[INFO]${NC} $1"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
error() { echo -e "${RED}[ERROR]${NC} $1"; }
# --- Build ---
build() {
info "Building CLIProxyAPIPlus..."
cd "$PROJECT_DIR"
go build -o "$BINARY" ./cmd/server/
info "Build successful: $BINARY"
}
# --- Create test config ---
create_config() {
cat > "$CONFIG" << 'EOF'
host: '127.0.0.1'
port: 8317
auth-dir: '~/.cli-proxy-api'
api-keys:
- 'quotio-local-D6ABC285-3085-44B4-B872-BD269888811F'
debug: true
EOF
info "Test config created: $CONFIG"
}
# --- Login ---
do_login() {
build
create_config
info "Starting Cursor login (will open browser)..."
"$BINARY" --config "$CONFIG" --cursor-login
}
# --- Start server ---
start_server() {
# Kill any existing instance
stop_server 2>/dev/null || true
build
create_config
info "Starting server on port 8317..."
"$BINARY" --config "$CONFIG" &
SERVER_PID=$!
echo "$SERVER_PID" > "$PID_FILE"
info "Server started (PID: $SERVER_PID)"
# Wait for server to be ready
info "Waiting for server to be ready..."
for i in $(seq 1 15); do
if curl -s "$BASE_URL/v1/models" -H "Authorization: Bearer $API_KEY" > /dev/null 2>&1; then
info "Server is ready!"
return 0
fi
sleep 1
done
error "Server failed to start within 15 seconds"
return 1
}
# --- Stop server ---
stop_server() {
if [ -f "$PID_FILE" ]; then
PID=$(cat "$PID_FILE")
if kill -0 "$PID" 2>/dev/null; then
info "Stopping server (PID: $PID)..."
kill "$PID"
rm -f "$PID_FILE"
fi
fi
# Also kill any stale process on port 8317
lsof -ti:8317 2>/dev/null | xargs kill 2>/dev/null || true
}
# --- Test: List models ---
test_models() {
info "Testing GET /v1/models (looking for cursor models)..."
RESPONSE=$(curl -s "$BASE_URL/v1/models" \
-H "Authorization: Bearer $API_KEY")
CURSOR_MODELS=$(echo "$RESPONSE" | python3 -c "
import json, sys
try:
data = json.load(sys.stdin)
models = [m['id'] for m in data.get('data', []) if m.get('owned_by') == 'cursor' or m.get('type') == 'cursor']
if models:
print('\n'.join(models))
else:
print('NONE')
except:
print('ERROR')
" 2>/dev/null || echo "PARSE_ERROR")
if [ "$CURSOR_MODELS" = "NONE" ] || [ "$CURSOR_MODELS" = "ERROR" ] || [ "$CURSOR_MODELS" = "PARSE_ERROR" ]; then
warn "No cursor models found. Have you run '--cursor-login' first?"
echo " Response preview: $(echo "$RESPONSE" | head -c 200)"
return 1
else
info "Found cursor models:"
echo "$CURSOR_MODELS" | while read -r model; do
echo " - $model"
done
return 0
fi
}
# --- Test: Chat completion (streaming) ---
test_chat_stream() {
local model="${1:-cursor-small}"
info "Testing POST /v1/chat/completions (stream, model=$model)..."
RESPONSE=$(curl -s --max-time 30 "$BASE_URL/v1/chat/completions" \
-H "Authorization: Bearer $API_KEY" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"$model\",
\"messages\": [{\"role\": \"user\", \"content\": \"Say hello in exactly 3 words.\"}],
\"stream\": true
}" 2>&1)
# Check if we got SSE data
if echo "$RESPONSE" | grep -q "data:"; then
# Extract content from SSE chunks
CONTENT=$(echo "$RESPONSE" | grep "^data: " | grep -v "\[DONE\]" | while read -r line; do
echo "${line#data: }" | python3 -c "
import json, sys
try:
chunk = json.load(sys.stdin)
delta = chunk.get('choices', [{}])[0].get('delta', {})
content = delta.get('content', '')
if content:
sys.stdout.write(content)
except:
pass
" 2>/dev/null
done)
if [ -n "$CONTENT" ]; then
info "Stream response received:"
echo " Content: $CONTENT"
return 0
else
warn "Got SSE chunks but no content extracted"
echo " Raw (first 500 chars): $(echo "$RESPONSE" | head -c 500)"
return 1
fi
else
error "No SSE data received"
echo " Response: $(echo "$RESPONSE" | head -c 300)"
return 1
fi
}
# --- Test: Chat completion (non-streaming) ---
test_chat_nonstream() {
local model="${1:-cursor-small}"
info "Testing POST /v1/chat/completions (non-stream, model=$model)..."
RESPONSE=$(curl -s --max-time 30 "$BASE_URL/v1/chat/completions" \
-H "Authorization: Bearer $API_KEY" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"$model\",
\"messages\": [{\"role\": \"user\", \"content\": \"What is 2+2? Answer with just the number.\"}],
\"stream\": false
}" 2>&1)
CONTENT=$(echo "$RESPONSE" | python3 -c "
import json, sys
try:
data = json.load(sys.stdin)
content = data['choices'][0]['message']['content']
print(content)
except Exception as e:
print(f'ERROR: {e}')
" 2>/dev/null || echo "PARSE_ERROR")
if echo "$CONTENT" | grep -q "ERROR\|PARSE_ERROR"; then
error "Non-streaming request failed"
echo " Response: $(echo "$RESPONSE" | head -c 300)"
return 1
else
info "Non-stream response received:"
echo " Content: $CONTENT"
return 0
fi
}
# --- Run all tests ---
run_tests() {
local passed=0
local failed=0
echo ""
echo "========================================="
echo " Cursor Proxy Integration Tests"
echo "========================================="
echo ""
# Test 1: Models
if test_models; then
((passed++))
else
((failed++))
fi
echo ""
# Test 2: Streaming chat
if test_chat_stream "cursor-small"; then
((passed++))
else
((failed++))
fi
echo ""
# Test 3: Non-streaming chat
if test_chat_nonstream "cursor-small"; then
((passed++))
else
((failed++))
fi
echo ""
echo "========================================="
echo " Results: ${passed} passed, ${failed} failed"
echo "========================================="
[ "$failed" -eq 0 ]
}
# --- Cleanup ---
cleanup() {
stop_server
rm -f "$BINARY" "$CONFIG"
info "Cleaned up."
}
# --- Main ---
case "${1:-help}" in
login)
do_login
;;
start)
start_server
info "Server running. Use './test_cursor.sh test' to run tests."
info "Use './test_cursor.sh stop' to stop."
;;
stop)
stop_server
;;
test)
run_tests
;;
all)
info "=== Full flow: login -> start -> test ==="
echo ""
info "Step 1: Login to Cursor"
do_login
echo ""
info "Step 2: Start server"
start_server
echo ""
info "Step 3: Run tests"
sleep 2
run_tests
echo ""
info "Step 4: Cleanup"
stop_server
;;
clean)
cleanup
;;
*)
echo "Usage: $0 {login|start|stop|test|all|clean}"
echo ""
echo " login - Authenticate with Cursor (opens browser)"
echo " start - Build and start the proxy server"
echo " stop - Stop the running server"
echo " test - Run API tests against running server"
echo " all - Full flow: login + start + test"
echo " clean - Stop server and remove artifacts"
;;
esac