mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-06 21:08:39 +00:00
When an assistant message appears after tool results without a pending user message, append it to the last turn's assistant text instead of dropping it. Also add bakeToolResultsIntoTurns() to merge tool results into turn context when no active H2 session exists for resume, ensuring the model sees the full tool interaction history in follow-up requests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1390 lines
44 KiB
Go
1390 lines
44 KiB
Go
package executor
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
|
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// Constants for the Cursor Connect+Protobuf API.
const (
	// cursorAPIURL is the base URL of the Cursor backend.
	cursorAPIURL = "https://api2.cursor.sh"
	// cursorRunPath is the AgentService streaming run endpoint.
	cursorRunPath = "/agent.v1.AgentService/Run"
	// cursorModelsPath lists the models usable by this account.
	cursorModelsPath = "/agent.v1.AgentService/GetUsableModels"
	// cursorClientVersion is the client version string sent in headers.
	cursorClientVersion = "cli-2026.02.13-41ac335"
	// cursorAuthType is the provider identifier for this executor.
	cursorAuthType = "cursor"
	// cursorHeartbeatInterval is how often a heartbeat frame is written
	// to keep the H2 stream alive.
	cursorHeartbeatInterval = 5 * time.Second
	// cursorSessionTTL is how long a saved MCP resume session is kept
	// before cleanupLoop expires it.
	cursorSessionTTL = 5 * time.Minute
)
|
|
|
|
// CursorExecutor handles requests to the Cursor API via Connect+Protobuf protocol.
type CursorExecutor struct {
	// cfg is the proxy configuration captured at construction.
	// NOTE(review): not referenced by the methods visible here — presumably
	// used elsewhere; confirm before removing.
	cfg *config.Config
	// mu guards sessions.
	mu sync.Mutex
	// sessions maps a derived conversation key to a live H2 session that is
	// kept open across HTTP requests while an MCP tool call is in flight.
	sessions map[string]*cursorSession
}
|
|
|
|
// cursorSession is a live Cursor H2 stream preserved between HTTP requests
// while the client executes an MCP tool call and returns its result.
type cursorSession struct {
	// stream is the still-open H2 stream to Cursor.
	stream *cursorproto.H2Stream
	// blobStore is the KV blob storage shared with the server for this run.
	blobStore map[string][]byte
	// mcpTools are the MCP tool definitions advertised to the server.
	mcpTools []cursorproto.McpToolDef
	// pending holds tool calls awaiting results from the client.
	pending []pendingMcpExec
	cancel  context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request)
	// createdAt is used by cleanupLoop to expire the session after cursorSessionTTL.
	createdAt    time.Time
	toolResultCh chan []toolResultInfo             // receives tool results from the next HTTP request
	resumeOutCh  chan cliproxyexecutor.StreamChunk // output channel for resumed response
	switchOutput func(ch chan cliproxyexecutor.StreamChunk) // callback to switch output channel
}
|
|
|
|
// pendingMcpExec describes an MCP tool call requested by the Cursor server
// whose result has not yet been sent back on the stream.
type pendingMcpExec struct {
	// ExecMsgId is the protobuf exec message id, echoed back in the result.
	ExecMsgId uint32
	// ExecId is the server-assigned execution id.
	ExecId string
	// ToolCallId is the OpenAI-style tool_call id surfaced to the client.
	ToolCallId string
	// ToolName is the MCP tool name.
	ToolName string
	Args     string // JSON-encoded args
}
|
|
|
|
// NewCursorExecutor constructs a new executor instance.
|
|
func NewCursorExecutor(cfg *config.Config) *CursorExecutor {
|
|
e := &CursorExecutor{
|
|
cfg: cfg,
|
|
sessions: make(map[string]*cursorSession),
|
|
}
|
|
go e.cleanupLoop()
|
|
return e
|
|
}
|
|
|
|
// Identifier implements ProviderExecutor.
|
|
func (e *CursorExecutor) Identifier() string { return cursorAuthType }
|
|
|
|
// CloseExecutionSession implements ExecutionSessionCloser.
|
|
func (e *CursorExecutor) CloseExecutionSession(sessionID string) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
|
for k, s := range e.sessions {
|
|
s.cancel()
|
|
delete(e.sessions, k)
|
|
}
|
|
return
|
|
}
|
|
if s, ok := e.sessions[sessionID]; ok {
|
|
s.cancel()
|
|
delete(e.sessions, sessionID)
|
|
}
|
|
}
|
|
|
|
func (e *CursorExecutor) cleanupLoop() {
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
e.mu.Lock()
|
|
for k, s := range e.sessions {
|
|
if time.Since(s.createdAt) > cursorSessionTTL {
|
|
s.cancel()
|
|
delete(e.sessions, k)
|
|
}
|
|
}
|
|
e.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// PrepareRequest implements ProviderExecutor (for HttpRequest support).
|
|
func (e *CursorExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
|
token := cursorAccessToken(auth)
|
|
if token == "" {
|
|
return fmt.Errorf("cursor: access token not found")
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
return nil
|
|
}
|
|
|
|
// HttpRequest injects credentials and executes the request.
|
|
func (e *CursorExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("cursor: request is nil")
|
|
}
|
|
if err := e.PrepareRequest(req, auth); err != nil {
|
|
return nil, err
|
|
}
|
|
return http.DefaultClient.Do(req)
|
|
}
|
|
|
|
// CountTokens estimates token count locally using tiktoken.
// Format is sniffed from the payload: a top-level "system" field (or a
// "claude" source format) selects the Claude /v1/messages counter,
// otherwise the OpenAI /v1/chat/completions counter is used. When no
// tokenizer is available for the model, a zero-token usage payload is
// returned with a nil error so upstream does not surface a 502.
func (e *CursorExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	// Named results let this defer log the final outcome.
	defer func() {
		if err != nil {
			log.Warnf("cursor CountTokens error: %v", err)
		} else {
			log.Debugf("cursor CountTokens: model=%s result=%s", req.Model, string(resp.Payload))
		}
	}()
	// Prefer the model named in the payload; fall back to the request field.
	model := gjson.GetBytes(req.Payload, "model").String()
	if model == "" {
		model = req.Model
	}

	enc, err := getTokenizer(model)
	if err != nil {
		// Fallback: return zero tokens rather than error (avoids 502).
		// The explicit nil error also resets `err` before the defer runs.
		return cliproxyexecutor.Response{Payload: buildOpenAIUsageJSON(0)}, nil
	}

	// Detect format: Claude (/v1/messages) vs OpenAI (/v1/chat/completions).
	// Count errors are intentionally ignored; a partial count beats a 502.
	var count int64
	if gjson.GetBytes(req.Payload, "system").Exists() || opts.SourceFormat.String() == "claude" {
		count, _ = countClaudeChatTokens(enc, req.Payload)
	} else {
		count, _ = countOpenAIChatTokens(enc, req.Payload)
	}

	return cliproxyexecutor.Response{Payload: buildOpenAIUsageJSON(count)}, nil
}
|
|
|
|
// Refresh attempts to refresh the Cursor access token.
|
|
func (e *CursorExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
refreshToken := cursorRefreshToken(auth)
|
|
if refreshToken == "" {
|
|
return nil, fmt.Errorf("cursor: no refresh token available")
|
|
}
|
|
|
|
tokens, err := cursorauth.RefreshToken(ctx, refreshToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
expiresAt := cursorauth.GetTokenExpiry(tokens.AccessToken)
|
|
|
|
newAuth := auth.Clone()
|
|
newAuth.Metadata["access_token"] = tokens.AccessToken
|
|
newAuth.Metadata["refresh_token"] = tokens.RefreshToken
|
|
newAuth.Metadata["expires_at"] = expiresAt.Format(time.RFC3339)
|
|
return newAuth, nil
|
|
}
|
|
|
|
// Execute handles non-streaming requests. It translates the incoming
// payload to OpenAI format when needed, encodes a Cursor RunRequest,
// drives the H2 stream to completion while collecting the full response
// text, wraps the text in an OpenAI chat.completion envelope, and
// translates the result back to the source format if required.
// Panics inside the call are recovered and converted into errors.
func (e *CursorExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	log.Debugf("cursor Execute: model=%s sourceFormat=%s payloadLen=%d", req.Model, opts.SourceFormat, len(req.Payload))
	defer func() {
		// Convert panics into errors so a bad frame cannot kill the proxy.
		if r := recover(); r != nil {
			log.Errorf("cursor Execute PANIC: %v", r)
			err = fmt.Errorf("cursor: internal panic: %v", r)
		}
		if err != nil {
			log.Warnf("cursor Execute error: %v", err)
		}
	}()
	accessToken := cursorAccessToken(auth)
	if accessToken == "" {
		return resp, fmt.Errorf("cursor: access token not found")
	}

	// Translate input to OpenAI format if needed (e.g. Claude /v1/messages format)
	from := opts.SourceFormat
	to := sdktranslator.FromString("openai")
	payload := req.Payload
	if from.String() != "" && from.String() != "openai" {
		payload = sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(payload), false)
	}

	parsed := parseOpenAIRequest(payload)
	params := buildRunRequestParams(parsed)

	// Encode the protobuf RunRequest and wrap it in a Connect frame.
	requestBytes := cursorproto.EncodeRunRequest(params)
	framedRequest := cursorproto.FrameConnectMessage(requestBytes, 0)

	stream, err := openCursorH2Stream(accessToken)
	if err != nil {
		return resp, err
	}
	defer stream.Close()

	// Send the request frame
	if err := stream.Write(framedRequest); err != nil {
		return resp, fmt.Errorf("cursor: failed to send request: %w", err)
	}

	// Start heartbeat tied to the request context (non-streaming mode
	// never outlives the HTTP request, unlike ExecuteStream).
	sessionCtx, sessionCancel := context.WithCancel(ctx)
	defer sessionCancel()
	go cursorH2Heartbeat(sessionCtx, stream)

	// Collect full text from streaming response. Thinking deltas are
	// concatenated together with regular text (isThinking is ignored here).
	var fullText strings.Builder
	processH2SessionFrames(sessionCtx, stream, params.BlobStore, nil,
		func(text string, isThinking bool) {
			fullText.WriteString(text)
		},
		nil,
		nil,
		nil, // tokenUsage - non-streaming
	)

	// Synthesize an OpenAI chat.completion response around the collected text.
	id := "chatcmpl-" + uuid.New().String()[:28]
	created := time.Now().Unix()
	openaiResp := fmt.Sprintf(`{"id":"%s","object":"chat.completion","created":%d,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`,
		id, created, parsed.Model, jsonString(fullText.String()))

	// Translate response back to source format if needed
	result := []byte(openaiResp)
	if from.String() != "" && from.String() != "openai" {
		var param any
		result = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), payload, result, &param)
	}
	resp.Payload = result
	return resp, nil
}
|
|
|
|
// ExecuteStream handles streaming requests.
|
|
// It supports MCP tool call sessions: when Cursor returns an MCP tool call,
|
|
// the H2 stream is kept alive. When Claude Code returns the tool result in
|
|
// the next request, the result is sent back on the same stream (session resume).
|
|
// This mirrors the activeSessions/resumeWithToolResults pattern in cursor-fetch.ts.
|
|
func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
|
log.Debugf("cursor ExecuteStream: model=%s sourceFormat=%s payloadLen=%d", req.Model, opts.SourceFormat, len(req.Payload))
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Errorf("cursor ExecuteStream PANIC: %v", r)
|
|
err = fmt.Errorf("cursor: internal panic: %v", r)
|
|
}
|
|
if err != nil {
|
|
log.Warnf("cursor ExecuteStream error: %v", err)
|
|
}
|
|
}()
|
|
accessToken := cursorAccessToken(auth)
|
|
if accessToken == "" {
|
|
return nil, fmt.Errorf("cursor: access token not found")
|
|
}
|
|
|
|
// Translate input to OpenAI format if needed
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("openai")
|
|
payload := req.Payload
|
|
originalPayload := bytes.Clone(req.Payload)
|
|
if len(opts.OriginalRequest) > 0 {
|
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
|
}
|
|
if from.String() != "" && from.String() != "openai" {
|
|
log.Debugf("cursor: translating request from %s to openai", from)
|
|
payload = sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(payload), true)
|
|
log.Debugf("cursor: translated payload len=%d", len(payload))
|
|
}
|
|
|
|
parsed := parseOpenAIRequest(payload)
|
|
log.Debugf("cursor: parsed request: model=%s userText=%d chars, turns=%d, tools=%d, toolResults=%d",
|
|
parsed.Model, len(parsed.UserText), len(parsed.Turns), len(parsed.Tools), len(parsed.ToolResults))
|
|
|
|
sessionKey := deriveSessionKey(parsed.Model, parsed.Messages)
|
|
needsTranslate := from.String() != "" && from.String() != "openai"
|
|
|
|
// Check if we can resume an existing session with tool results
|
|
if len(parsed.ToolResults) > 0 {
|
|
e.mu.Lock()
|
|
session, hasSession := e.sessions[sessionKey]
|
|
if hasSession {
|
|
delete(e.sessions, sessionKey)
|
|
}
|
|
e.mu.Unlock()
|
|
|
|
if hasSession && session.stream != nil {
|
|
log.Debugf("cursor: resuming session %s with %d tool results", sessionKey, len(parsed.ToolResults))
|
|
return e.resumeWithToolResults(ctx, session, parsed, from, to, req, originalPayload, payload, needsTranslate)
|
|
}
|
|
}
|
|
|
|
// Clean up any stale session for this key
|
|
e.mu.Lock()
|
|
if old, ok := e.sessions[sessionKey]; ok {
|
|
old.cancel()
|
|
delete(e.sessions, sessionKey)
|
|
}
|
|
e.mu.Unlock()
|
|
|
|
// If tool results exist but no session to resume, bake them into turns
|
|
// so the model sees tool interaction context in the new conversation.
|
|
if len(parsed.ToolResults) > 0 {
|
|
log.Debugf("cursor: no session to resume, baking %d tool results into turns", len(parsed.ToolResults))
|
|
bakeToolResultsIntoTurns(parsed)
|
|
}
|
|
|
|
params := buildRunRequestParams(parsed)
|
|
requestBytes := cursorproto.EncodeRunRequest(params)
|
|
framedRequest := cursorproto.FrameConnectMessage(requestBytes, 0)
|
|
|
|
stream, err := openCursorH2Stream(accessToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := stream.Write(framedRequest); err != nil {
|
|
stream.Close()
|
|
return nil, fmt.Errorf("cursor: failed to send request: %w", err)
|
|
}
|
|
|
|
// Use a session-scoped context for the heartbeat that is NOT tied to the HTTP request.
|
|
// This ensures the heartbeat survives across request boundaries during MCP tool execution.
|
|
// Mirrors the TS plugin's setInterval-based heartbeat that lives independently of HTTP responses.
|
|
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
|
go cursorH2Heartbeat(sessionCtx, stream)
|
|
|
|
chunks := make(chan cliproxyexecutor.StreamChunk, 64)
|
|
chatId := "chatcmpl-" + uuid.New().String()[:28]
|
|
created := time.Now().Unix()
|
|
|
|
var streamParam any
|
|
|
|
// Tool result channel for inline mode. processH2SessionFrames blocks on it
|
|
// when mcpArgs is received, while continuing to handle KV/heartbeat.
|
|
toolResultCh := make(chan []toolResultInfo, 1)
|
|
|
|
// Switchable output: initially writes to `chunks`. After mcpArgs, the
|
|
// onMcpExec callback closes `chunks` (ending the first HTTP response),
|
|
// then processH2SessionFrames blocks on toolResultCh. When results arrive,
|
|
// it switches to `resumeOutCh` (created by resumeWithToolResults).
|
|
var outMu sync.Mutex
|
|
currentOut := chunks
|
|
|
|
emitToOut := func(chunk cliproxyexecutor.StreamChunk) {
|
|
outMu.Lock()
|
|
out := currentOut
|
|
outMu.Unlock()
|
|
if out != nil {
|
|
out <- chunk
|
|
}
|
|
}
|
|
|
|
// Wrap sendChunk/sendDone to use emitToOut
|
|
sendChunkSwitchable := func(delta string, finishReason string) {
|
|
fr := "null"
|
|
if finishReason != "" {
|
|
fr = finishReason
|
|
}
|
|
openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`,
|
|
chatId, created, parsed.Model, delta, fr)
|
|
sseLine := []byte("data: " + openaiJSON + "\n")
|
|
|
|
if needsTranslate {
|
|
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
|
|
for _, t := range translated {
|
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)})
|
|
}
|
|
} else {
|
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)})
|
|
}
|
|
}
|
|
|
|
sendDoneSwitchable := func() {
|
|
if needsTranslate {
|
|
done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam)
|
|
for _, d := range done {
|
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)})
|
|
}
|
|
} else {
|
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")})
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
var resumeOutCh chan cliproxyexecutor.StreamChunk
|
|
_ = resumeOutCh
|
|
thinkingActive := false
|
|
toolCallIndex := 0
|
|
usage := &cursorTokenUsage{}
|
|
usage.setInputEstimate(len(payload))
|
|
|
|
processH2SessionFrames(sessionCtx, stream, params.BlobStore, params.McpTools,
|
|
func(text string, isThinking bool) {
|
|
if isThinking {
|
|
if !thinkingActive {
|
|
thinkingActive = true
|
|
sendChunkSwitchable(`{"role":"assistant","content":"<think>"}`, "")
|
|
}
|
|
sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
|
} else {
|
|
if thinkingActive {
|
|
thinkingActive = false
|
|
sendChunkSwitchable(`{"content":"</think>"}`, "")
|
|
}
|
|
sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
|
}
|
|
},
|
|
func(exec pendingMcpExec) {
|
|
if thinkingActive {
|
|
thinkingActive = false
|
|
sendChunkSwitchable(`{"content":"</think>"}`, "")
|
|
}
|
|
toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`,
|
|
toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args))
|
|
toolCallIndex++
|
|
sendChunkSwitchable(toolCallJSON, "")
|
|
sendChunkSwitchable(`{}`, `"tool_calls"`)
|
|
sendDoneSwitchable()
|
|
|
|
// Close current output to end the current HTTP SSE response
|
|
outMu.Lock()
|
|
if currentOut != nil {
|
|
close(currentOut)
|
|
currentOut = nil
|
|
}
|
|
outMu.Unlock()
|
|
|
|
// Create new resume output channel, reuse the same toolResultCh
|
|
resumeOut := make(chan cliproxyexecutor.StreamChunk, 64)
|
|
log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s)", sessionKey, exec.ToolName)
|
|
e.mu.Lock()
|
|
e.sessions[sessionKey] = &cursorSession{
|
|
stream: stream,
|
|
blobStore: params.BlobStore,
|
|
mcpTools: params.McpTools,
|
|
pending: []pendingMcpExec{exec},
|
|
cancel: sessionCancel,
|
|
createdAt: time.Now(),
|
|
toolResultCh: toolResultCh, // reuse same channel across rounds
|
|
resumeOutCh: resumeOut,
|
|
switchOutput: func(ch chan cliproxyexecutor.StreamChunk) {
|
|
outMu.Lock()
|
|
currentOut = ch
|
|
// Reset translator state so the new HTTP response gets
|
|
// a fresh message_start, content_block_start, etc.
|
|
streamParam = nil
|
|
// New response needs its own message ID
|
|
chatId = "chatcmpl-" + uuid.New().String()[:28]
|
|
created = time.Now().Unix()
|
|
outMu.Unlock()
|
|
},
|
|
}
|
|
e.mu.Unlock()
|
|
resumeOutCh = resumeOut
|
|
|
|
// processH2SessionFrames will now block on toolResultCh (inline wait loop)
|
|
// while continuing to handle KV messages
|
|
},
|
|
toolResultCh,
|
|
usage,
|
|
)
|
|
|
|
// processH2SessionFrames returned — stream is done
|
|
if thinkingActive {
|
|
sendChunkSwitchable(`{"content":"</think>"}`, "")
|
|
}
|
|
// Include token usage in the final stop chunk
|
|
inputTok, outputTok := usage.get()
|
|
stopDelta := fmt.Sprintf(`{},"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}`,
|
|
inputTok, outputTok, inputTok+outputTok)
|
|
// Build the stop chunk with usage embedded in the choices array level
|
|
fr := `"stop"`
|
|
openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":{},"finish_reason":%s}],"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}`,
|
|
chatId, created, parsed.Model, fr, inputTok, outputTok, inputTok+outputTok)
|
|
sseLine := []byte("data: " + openaiJSON + "\n")
|
|
if needsTranslate {
|
|
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
|
|
for _, t := range translated {
|
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)})
|
|
}
|
|
} else {
|
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)})
|
|
}
|
|
sendDoneSwitchable()
|
|
_ = stopDelta // unused
|
|
|
|
// Close whatever output channel is still active
|
|
outMu.Lock()
|
|
if currentOut != nil {
|
|
close(currentOut)
|
|
currentOut = nil
|
|
}
|
|
outMu.Unlock()
|
|
sessionCancel()
|
|
stream.Close()
|
|
}()
|
|
|
|
return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
|
|
}
|
|
|
|
// resumeWithToolResults injects tool results into the running processH2SessionFrames
|
|
// via the toolResultCh channel. The original goroutine from ExecuteStream is still alive,
|
|
// blocking on toolResultCh. Once we send the results, it sends the MCP result to Cursor
|
|
// and continues processing the response text — all in the same goroutine that has been
|
|
// handling KV messages the whole time.
|
|
func (e *CursorExecutor) resumeWithToolResults(
|
|
ctx context.Context,
|
|
session *cursorSession,
|
|
parsed *parsedOpenAIRequest,
|
|
from, to sdktranslator.Format,
|
|
req cliproxyexecutor.Request,
|
|
originalPayload, payload []byte,
|
|
needsTranslate bool,
|
|
) (*cliproxyexecutor.StreamResult, error) {
|
|
log.Debugf("cursor: resumeWithToolResults: injecting %d tool results via channel", len(parsed.ToolResults))
|
|
|
|
if session.toolResultCh == nil {
|
|
return nil, fmt.Errorf("cursor: session has no toolResultCh (stale session?)")
|
|
}
|
|
if session.resumeOutCh == nil {
|
|
return nil, fmt.Errorf("cursor: session has no resumeOutCh")
|
|
}
|
|
|
|
log.Debugf("cursor: resumeWithToolResults: switching output to resumeOutCh and injecting results")
|
|
|
|
// Switch the output channel BEFORE injecting results, so that when
|
|
// processH2SessionFrames unblocks and starts emitting text, it writes
|
|
// to the resumeOutCh which the new HTTP handler is reading from.
|
|
if session.switchOutput != nil {
|
|
session.switchOutput(session.resumeOutCh)
|
|
}
|
|
|
|
// Inject tool results — this unblocks the waiting processH2SessionFrames
|
|
session.toolResultCh <- parsed.ToolResults
|
|
|
|
// Return the resumeOutCh for the new HTTP handler to read from
|
|
return &cliproxyexecutor.StreamResult{Chunks: session.resumeOutCh}, nil
|
|
}
|
|
|
|
// --- H2Stream helpers ---
|
|
|
|
func openCursorH2Stream(accessToken string) (*cursorproto.H2Stream, error) {
|
|
headers := map[string]string{
|
|
":path": cursorRunPath,
|
|
"content-type": "application/connect+proto",
|
|
"connect-protocol-version": "1",
|
|
"te": "trailers",
|
|
"authorization": "Bearer " + accessToken,
|
|
"x-ghost-mode": "true",
|
|
"x-cursor-client-version": cursorClientVersion,
|
|
"x-cursor-client-type": "cli",
|
|
"x-request-id": uuid.New().String(),
|
|
}
|
|
return cursorproto.DialH2Stream("api2.cursor.sh", headers)
|
|
}
|
|
|
|
func cursorH2Heartbeat(ctx context.Context, stream *cursorproto.H2Stream) {
|
|
ticker := time.NewTicker(cursorHeartbeatInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
hb := cursorproto.EncodeHeartbeat()
|
|
frame := cursorproto.FrameConnectMessage(hb, 0)
|
|
if err := stream.Write(frame); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Response processing ---
|
|
|
|
// cursorTokenUsage tracks token counts from Cursor's TokenDeltaUpdate
// messages. All methods are safe for concurrent use.
type cursorTokenUsage struct {
	mu           sync.Mutex
	outputTokens int64
	// inputTokensEst is estimated from request payload size; Cursor does
	// not report input tokens directly.
	inputTokensEst int64
}

// addOutput accumulates output tokens reported by the server.
func (u *cursorTokenUsage) addOutput(delta int64) {
	u.mu.Lock()
	u.outputTokens += delta
	u.mu.Unlock()
}

// setInputEstimate derives an input-token estimate from the request
// payload size (~4 bytes per token), with a floor of 1.
func (u *cursorTokenUsage) setInputEstimate(payloadBytes int) {
	est := int64(payloadBytes / 4)
	if est < 1 {
		est = 1
	}
	u.mu.Lock()
	u.inputTokensEst = est
	u.mu.Unlock()
}

// get returns the current (input estimate, accumulated output) pair.
func (u *cursorTokenUsage) get() (input, output int64) {
	u.mu.Lock()
	defer u.mu.Unlock()
	return u.inputTokensEst, u.outputTokens
}
|
|
|
|
// processH2SessionFrames is the main read loop for a Cursor H2 stream.
// It reassembles Connect frames from the stream's data channel, decodes
// AgentServerMessages, and dispatches them:
//   - text/thinking deltas go to onText;
//   - KV blob get/set and exec-context requests are answered inline;
//   - an MCP tool call is surfaced via onMcpExec; when toolResultCh is
//     non-nil the loop then blocks waiting for tool results (while still
//     servicing KV traffic on the same buffer) and writes the MCP result
//     back on the stream;
//   - filesystem/shell/fetch exec requests are rejected so the model
//     falls back to the provided MCP tools.
//
// The loop returns on context cancellation, stream closure, or a
// TurnEnded message. Write errors on replies are deliberately ignored;
// a dead stream surfaces via stream.Done() or the closed data channel.
func processH2SessionFrames(
	ctx context.Context,
	stream *cursorproto.H2Stream,
	blobStore map[string][]byte,
	mcpTools []cursorproto.McpToolDef,
	onText func(text string, isThinking bool),
	onMcpExec func(exec pendingMcpExec),
	toolResultCh <-chan []toolResultInfo, // nil for no tool result injection; non-nil to wait for results
	tokenUsage *cursorTokenUsage, // tracks accumulated token usage (may be nil)
) {
	// buf accumulates raw bytes until at least one complete Connect frame
	// is available.
	var buf bytes.Buffer
	rejectReason := "Tool not available in this environment. Use the MCP tools provided instead."
	log.Debugf("cursor: processH2SessionFrames started for streamID=%s, waiting for data...", stream.ID())
	for {
		select {
		case <-ctx.Done():
			log.Debugf("cursor: processH2SessionFrames exiting: context done")
			return
		case data, ok := <-stream.Data():
			if !ok {
				log.Debugf("cursor: processH2SessionFrames[%s]: exiting: stream data channel closed", stream.ID())
				return
			}
			// Log first 20 bytes of raw data for debugging
			previewLen := min(20, len(data))
			log.Debugf("cursor: processH2SessionFrames[%s]: received %d bytes from dataCh, first bytes: %x (%q)", stream.ID(), len(data), data[:previewLen], string(data[:previewLen]))
			buf.Write(data)
			log.Debugf("cursor: processH2SessionFrames[%s]: buf total=%d", stream.ID(), buf.Len())

			// Process all complete frames
			for {
				currentBuf := buf.Bytes()
				if len(currentBuf) == 0 {
					break
				}
				flags, payload, consumed, ok := cursorproto.ParseConnectFrame(currentBuf)
				if !ok {
					// Log detailed info about why parsing failed
					previewLen := min(20, len(currentBuf))
					log.Debugf("cursor: incomplete frame in buffer, waiting for more data (buf=%d bytes, first bytes: %x = %q)", len(currentBuf), currentBuf[:previewLen], string(currentBuf[:previewLen]))
					break
				}
				buf.Next(consumed)
				log.Debugf("cursor: parsed Connect frame flags=0x%02x payload=%d bytes consumed=%d", flags, len(payload), consumed)

				// An end-stream frame carries status/trailers, not a message.
				if flags&cursorproto.ConnectEndStreamFlag != 0 {
					if err := cursorproto.ParseConnectEndStream(payload); err != nil {
						log.Warnf("cursor: connect end stream error: %v", err)
					}
					continue
				}

				msg, err := cursorproto.DecodeAgentServerMessage(payload)
				if err != nil {
					// Skip undecodable messages rather than aborting the stream.
					log.Debugf("cursor: failed to decode server message: %v", err)
					continue
				}

				log.Debugf("cursor: decoded server message type=%d", msg.Type)
				switch msg.Type {
				case cursorproto.ServerMsgTextDelta:
					if msg.Text != "" && onText != nil {
						onText(msg.Text, false)
					}
				case cursorproto.ServerMsgThinkingDelta:
					if msg.Text != "" && onText != nil {
						onText(msg.Text, true)
					}
				case cursorproto.ServerMsgThinkingCompleted:
					// Handled by caller

				case cursorproto.ServerMsgTurnEnded:
					log.Debugf("cursor: TurnEnded received, stream will finish")
					return

				case cursorproto.ServerMsgHeartbeat:
					// Server heartbeat, ignore silently
					continue

				case cursorproto.ServerMsgTokenDelta:
					if tokenUsage != nil && msg.TokenDelta > 0 {
						tokenUsage.addOutput(msg.TokenDelta)
					}
					continue

				case cursorproto.ServerMsgKvGetBlob:
					// Serve blob reads from the local store.
					blobKey := cursorproto.BlobIdHex(msg.BlobId)
					data := blobStore[blobKey]
					resp := cursorproto.EncodeKvGetBlobResult(msg.KvId, data)
					stream.Write(cursorproto.FrameConnectMessage(resp, 0))

				case cursorproto.ServerMsgKvSetBlob:
					// Copy the blob data: msg.BlobData may alias the frame buffer.
					blobKey := cursorproto.BlobIdHex(msg.BlobId)
					blobStore[blobKey] = append([]byte(nil), msg.BlobData...)
					resp := cursorproto.EncodeKvSetBlobResult(msg.KvId)
					stream.Write(cursorproto.FrameConnectMessage(resp, 0))

				case cursorproto.ServerMsgExecRequestCtx:
					// Advertise the MCP tool set for this execution context.
					resp := cursorproto.EncodeExecRequestContextResult(msg.ExecMsgId, msg.ExecId, mcpTools)
					stream.Write(cursorproto.FrameConnectMessage(resp, 0))

				case cursorproto.ServerMsgExecMcpArgs:
					if onMcpExec != nil {
						decodedArgs := decodeMcpArgsToJSON(msg.McpArgs)
						toolCallId := msg.McpToolCallId
						if toolCallId == "" {
							// Synthesize an id when the server omits one.
							toolCallId = uuid.New().String()
						}
						log.Debugf("cursor: received mcpArgs from server: execMsgId=%d execId=%q toolName=%s toolCallId=%s",
							msg.ExecMsgId, msg.ExecId, msg.McpToolName, toolCallId)
						pending := pendingMcpExec{
							ExecMsgId:  msg.ExecMsgId,
							ExecId:     msg.ExecId,
							ToolCallId: toolCallId,
							ToolName:   msg.McpToolName,
							Args:       decodedArgs,
						}
						onMcpExec(pending)

						// Non-streaming / no-injection mode: the tool call ends
						// this response.
						if toolResultCh == nil {
							return
						}

						// Inline mode: wait for tool result while handling KV/heartbeat
						log.Debugf("cursor: waiting for tool result on channel (inline mode)...")
						var toolResults []toolResultInfo
					waitLoop:
						for {
							select {
							case <-ctx.Done():
								return
							case results, ok := <-toolResultCh:
								if !ok {
									return
								}
								toolResults = results
								break waitLoop
							case waitData, ok := <-stream.Data():
								if !ok {
									return
								}
								// Keep servicing KV/context traffic with the same
								// frame-reassembly buffer while waiting.
								buf.Write(waitData)
								for {
									cb := buf.Bytes()
									if len(cb) == 0 {
										break
									}
									wf, wp, wc, wok := cursorproto.ParseConnectFrame(cb)
									if !wok {
										break
									}
									buf.Next(wc)
									if wf&cursorproto.ConnectEndStreamFlag != 0 {
										continue
									}
									wmsg, werr := cursorproto.DecodeAgentServerMessage(wp)
									if werr != nil {
										continue
									}
									switch wmsg.Type {
									case cursorproto.ServerMsgKvGetBlob:
										blobKey := cursorproto.BlobIdHex(wmsg.BlobId)
										d := blobStore[blobKey]
										stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvGetBlobResult(wmsg.KvId, d), 0))
									case cursorproto.ServerMsgKvSetBlob:
										blobKey := cursorproto.BlobIdHex(wmsg.BlobId)
										blobStore[blobKey] = append([]byte(nil), wmsg.BlobData...)
										stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvSetBlobResult(wmsg.KvId), 0))
									case cursorproto.ServerMsgExecRequestCtx:
										stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecRequestContextResult(wmsg.ExecMsgId, wmsg.ExecId, mcpTools), 0))
									}
								}
							case <-stream.Done():
								return
							}
						}

						// Send MCP result
						for _, tr := range toolResults {
							if tr.ToolCallId == pending.ToolCallId {
								log.Debugf("cursor: sending inline MCP result for tool=%s", pending.ToolName)
								resultBytes := cursorproto.EncodeExecMcpResult(pending.ExecMsgId, pending.ExecId, tr.Content, false)
								stream.Write(cursorproto.FrameConnectMessage(resultBytes, 0))
								break
							}
						}
						continue
					}

				// All direct filesystem/shell/fetch exec requests are rejected:
				// this proxy has no local workspace, so the model must use the
				// MCP tools instead.
				case cursorproto.ServerMsgExecReadArgs:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecReadRejected(msg.ExecMsgId, msg.ExecId, msg.Path, rejectReason), 0))
				case cursorproto.ServerMsgExecWriteArgs:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecWriteRejected(msg.ExecMsgId, msg.ExecId, msg.Path, rejectReason), 0))
				case cursorproto.ServerMsgExecDeleteArgs:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecDeleteRejected(msg.ExecMsgId, msg.ExecId, msg.Path, rejectReason), 0))
				case cursorproto.ServerMsgExecLsArgs:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecLsRejected(msg.ExecMsgId, msg.ExecId, msg.Path, rejectReason), 0))
				case cursorproto.ServerMsgExecGrepArgs:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecGrepError(msg.ExecMsgId, msg.ExecId, rejectReason), 0))
				case cursorproto.ServerMsgExecShellArgs, cursorproto.ServerMsgExecShellStream:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecShellRejected(msg.ExecMsgId, msg.ExecId, msg.Command, msg.WorkingDirectory, rejectReason), 0))
				case cursorproto.ServerMsgExecBgShellSpawn:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecBackgroundShellSpawnRejected(msg.ExecMsgId, msg.ExecId, msg.Command, msg.WorkingDirectory, rejectReason), 0))
				case cursorproto.ServerMsgExecFetchArgs:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecFetchError(msg.ExecMsgId, msg.ExecId, msg.Url, rejectReason), 0))
				case cursorproto.ServerMsgExecDiagnostics:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecDiagnosticsResult(msg.ExecMsgId, msg.ExecId), 0))
				case cursorproto.ServerMsgExecWriteShellStdin:
					stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecWriteShellStdinError(msg.ExecMsgId, msg.ExecId, rejectReason), 0))
				}
			}

		case <-stream.Done():
			log.Debugf("cursor: processH2SessionFrames exiting: stream done")
			return
		}
	}
}
|
|
|
|
// --- OpenAI request parsing ---
|
|
|
|
// parsedOpenAIRequest holds the fields extracted from an OpenAI-style
// chat-completions payload before conversion into a Cursor Run request.
type parsedOpenAIRequest struct {
	// Model is the requested model identifier (payload "model").
	Model string
	// Messages is the raw "messages" array, kept for session-key derivation.
	Messages []gjson.Result
	// Tools is the raw OpenAI "tools" array, later mapped to McpToolDefs.
	Tools []gjson.Result
	// Stream mirrors the payload "stream" flag.
	Stream bool
	// SystemPrompt is all system messages joined with newlines, or a
	// generic default when the request carries none.
	SystemPrompt string
	// UserText is the trailing user message that starts the new turn.
	UserText string
	// Images holds inline data-URL images from the most recent user message.
	Images []cursorproto.ImageData
	// Turns is the prior user/assistant exchange history.
	Turns []cursorproto.TurnData
	// ToolResults collects "tool" role messages in payload order.
	ToolResults []toolResultInfo
}
|
|
|
|
// toolResultInfo is one OpenAI "tool" role message: the id of the tool
// call it answers and its textual content.
type toolResultInfo struct {
	// ToolCallId is the payload's "tool_call_id" field.
	ToolCallId string
	// Content is the flattened text content of the tool message.
	Content string
}
|
|
|
|
func parseOpenAIRequest(payload []byte) *parsedOpenAIRequest {
|
|
p := &parsedOpenAIRequest{
|
|
Model: gjson.GetBytes(payload, "model").String(),
|
|
Stream: gjson.GetBytes(payload, "stream").Bool(),
|
|
}
|
|
|
|
messages := gjson.GetBytes(payload, "messages").Array()
|
|
p.Messages = messages
|
|
|
|
// Extract system prompt
|
|
var systemParts []string
|
|
for _, msg := range messages {
|
|
if msg.Get("role").String() == "system" {
|
|
systemParts = append(systemParts, extractTextContent(msg.Get("content")))
|
|
}
|
|
}
|
|
if len(systemParts) > 0 {
|
|
p.SystemPrompt = strings.Join(systemParts, "\n")
|
|
} else {
|
|
p.SystemPrompt = "You are a helpful assistant."
|
|
}
|
|
|
|
// Extract turns, tool results, and last user message
|
|
var pendingUser string
|
|
for _, msg := range messages {
|
|
role := msg.Get("role").String()
|
|
switch role {
|
|
case "system":
|
|
continue
|
|
case "tool":
|
|
p.ToolResults = append(p.ToolResults, toolResultInfo{
|
|
ToolCallId: msg.Get("tool_call_id").String(),
|
|
Content: extractTextContent(msg.Get("content")),
|
|
})
|
|
case "user":
|
|
if pendingUser != "" {
|
|
p.Turns = append(p.Turns, cursorproto.TurnData{UserText: pendingUser})
|
|
}
|
|
pendingUser = extractTextContent(msg.Get("content"))
|
|
p.Images = extractImages(msg.Get("content"))
|
|
case "assistant":
|
|
assistantText := extractTextContent(msg.Get("content"))
|
|
if pendingUser != "" {
|
|
p.Turns = append(p.Turns, cursorproto.TurnData{
|
|
UserText: pendingUser,
|
|
AssistantText: assistantText,
|
|
})
|
|
pendingUser = ""
|
|
} else if len(p.Turns) > 0 && assistantText != "" {
|
|
// Assistant message after tool results (no pending user) —
|
|
// append to the last turn's assistant text to preserve context.
|
|
last := &p.Turns[len(p.Turns)-1]
|
|
if last.AssistantText != "" {
|
|
last.AssistantText += "\n" + assistantText
|
|
} else {
|
|
last.AssistantText = assistantText
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if pendingUser != "" {
|
|
p.UserText = pendingUser
|
|
} else if len(p.Turns) > 0 && len(p.ToolResults) == 0 {
|
|
last := p.Turns[len(p.Turns)-1]
|
|
p.Turns = p.Turns[:len(p.Turns)-1]
|
|
p.UserText = last.UserText
|
|
}
|
|
|
|
// Extract tools
|
|
p.Tools = gjson.GetBytes(payload, "tools").Array()
|
|
|
|
return p
|
|
}
|
|
|
|
// bakeToolResultsIntoTurns merges tool results into the last turn's assistant text
|
|
// when there's no active H2 session to resume. This ensures the model sees the
|
|
// full tool interaction context in a new conversation.
|
|
func bakeToolResultsIntoTurns(parsed *parsedOpenAIRequest) {
|
|
if len(parsed.ToolResults) == 0 || len(parsed.Turns) == 0 {
|
|
return
|
|
}
|
|
last := &parsed.Turns[len(parsed.Turns)-1]
|
|
var toolContext strings.Builder
|
|
for _, tr := range parsed.ToolResults {
|
|
toolContext.WriteString("\n\n[Tool Result]\n")
|
|
toolContext.WriteString(tr.Content)
|
|
}
|
|
if last.AssistantText != "" {
|
|
last.AssistantText += toolContext.String()
|
|
} else {
|
|
last.AssistantText = toolContext.String()
|
|
}
|
|
parsed.ToolResults = nil // consumed
|
|
}
|
|
|
|
func extractTextContent(content gjson.Result) string {
|
|
if content.Type == gjson.String {
|
|
return content.String()
|
|
}
|
|
if content.IsArray() {
|
|
var parts []string
|
|
for _, part := range content.Array() {
|
|
if part.Get("type").String() == "text" {
|
|
parts = append(parts, part.Get("text").String())
|
|
}
|
|
}
|
|
return strings.Join(parts, "")
|
|
}
|
|
return content.String()
|
|
}
|
|
|
|
func extractImages(content gjson.Result) []cursorproto.ImageData {
|
|
if !content.IsArray() {
|
|
return nil
|
|
}
|
|
var images []cursorproto.ImageData
|
|
for _, part := range content.Array() {
|
|
if part.Get("type").String() == "image_url" {
|
|
url := part.Get("image_url.url").String()
|
|
if strings.HasPrefix(url, "data:") {
|
|
img := parseDataURL(url)
|
|
if img != nil {
|
|
images = append(images, *img)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return images
|
|
}
|
|
|
|
func parseDataURL(url string) *cursorproto.ImageData {
|
|
// data:image/png;base64,...
|
|
if !strings.HasPrefix(url, "data:") {
|
|
return nil
|
|
}
|
|
parts := strings.SplitN(url[5:], ";", 2)
|
|
if len(parts) != 2 {
|
|
return nil
|
|
}
|
|
mimeType := parts[0]
|
|
if !strings.HasPrefix(parts[1], "base64,") {
|
|
return nil
|
|
}
|
|
encoded := parts[1][7:]
|
|
data, err := base64.StdEncoding.DecodeString(encoded)
|
|
if err != nil {
|
|
// Try RawStdEncoding for unpadded base64
|
|
data, err = base64.RawStdEncoding.DecodeString(encoded)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
return &cursorproto.ImageData{
|
|
MimeType: mimeType,
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
func buildRunRequestParams(parsed *parsedOpenAIRequest) *cursorproto.RunRequestParams {
|
|
params := &cursorproto.RunRequestParams{
|
|
ModelId: parsed.Model,
|
|
SystemPrompt: parsed.SystemPrompt,
|
|
UserText: parsed.UserText,
|
|
MessageId: uuid.New().String(),
|
|
ConversationId: uuid.New().String(),
|
|
Images: parsed.Images,
|
|
Turns: parsed.Turns,
|
|
BlobStore: make(map[string][]byte),
|
|
}
|
|
|
|
// Convert OpenAI tools to McpToolDefs
|
|
for _, tool := range parsed.Tools {
|
|
fn := tool.Get("function")
|
|
params.McpTools = append(params.McpTools, cursorproto.McpToolDef{
|
|
Name: fn.Get("name").String(),
|
|
Description: fn.Get("description").String(),
|
|
InputSchema: json.RawMessage(fn.Get("parameters").Raw),
|
|
})
|
|
}
|
|
|
|
return params
|
|
}
|
|
|
|
// --- Helpers ---
|
|
|
|
func cursorAccessToken(auth *cliproxyauth.Auth) string {
|
|
if auth == nil || auth.Metadata == nil {
|
|
return ""
|
|
}
|
|
if v, ok := auth.Metadata["access_token"].(string); ok {
|
|
return v
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func cursorRefreshToken(auth *cliproxyauth.Auth) string {
|
|
if auth == nil || auth.Metadata == nil {
|
|
return ""
|
|
}
|
|
if v, ok := auth.Metadata["refresh_token"].(string); ok {
|
|
return v
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func applyCursorHeaders(req *http.Request, accessToken string) {
|
|
req.Header.Set("Content-Type", "application/connect+proto")
|
|
req.Header.Set("Connect-Protocol-Version", "1")
|
|
req.Header.Set("Te", "trailers")
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
req.Header.Set("X-Ghost-Mode", "true")
|
|
req.Header.Set("X-Cursor-Client-Version", cursorClientVersion)
|
|
req.Header.Set("X-Cursor-Client-Type", "cli")
|
|
req.Header.Set("X-Request-Id", uuid.New().String())
|
|
}
|
|
|
|
func newH2Client() *http.Client {
|
|
return &http.Client{
|
|
Transport: &http2.Transport{
|
|
TLSClientConfig: &tls.Config{},
|
|
},
|
|
}
|
|
}
|
|
|
|
func deriveSessionKey(model string, messages []gjson.Result) string {
|
|
var firstUserContent string
|
|
for _, msg := range messages {
|
|
if msg.Get("role").String() == "user" {
|
|
firstUserContent = extractTextContent(msg.Get("content"))
|
|
break
|
|
}
|
|
}
|
|
input := model + ":" + firstUserContent
|
|
if len(input) > 200 {
|
|
input = input[:200]
|
|
}
|
|
h := sha256.Sum256([]byte(input))
|
|
return hex.EncodeToString(h[:])[:16]
|
|
}
|
|
|
|
func sseChunk(id string, created int64, model string, delta string, finishReason string) cliproxyexecutor.StreamChunk {
|
|
fr := "null"
|
|
if finishReason != "" {
|
|
fr = finishReason
|
|
}
|
|
// Note: the framework's WriteChunk adds "data: " prefix and "\n\n" suffix,
|
|
// so we only output the raw JSON here.
|
|
data := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`,
|
|
id, created, model, delta, fr)
|
|
return cliproxyexecutor.StreamChunk{
|
|
Payload: []byte(data),
|
|
}
|
|
}
|
|
|
|
// jsonString returns s encoded as a JSON string literal (quoted and
// escaped).
func jsonString(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		// Unreachable: marshaling a plain string never fails.
		return `""`
	}
	return string(encoded)
}
|
|
|
|
func decodeMcpArgsToJSON(args map[string][]byte) string {
|
|
if len(args) == 0 {
|
|
return "{}"
|
|
}
|
|
result := make(map[string]interface{})
|
|
for k, v := range args {
|
|
// Try protobuf Value decoding first (matches TS: toJson(ValueSchema, fromBinary(ValueSchema, value)))
|
|
if decoded, err := cursorproto.ProtobufValueBytesToJSON(v); err == nil {
|
|
result[k] = decoded
|
|
} else {
|
|
// Fallback: try raw JSON
|
|
var jsonVal interface{}
|
|
if err := json.Unmarshal(v, &jsonVal); err == nil {
|
|
result[k] = jsonVal
|
|
} else {
|
|
result[k] = string(v)
|
|
}
|
|
}
|
|
}
|
|
b, _ := json.Marshal(result)
|
|
return string(b)
|
|
}
|
|
|
|
// --- Model Discovery ---
|
|
|
|
// FetchCursorModels retrieves available models from Cursor's API.
|
|
func FetchCursorModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
|
accessToken := cursorAccessToken(auth)
|
|
if accessToken == "" {
|
|
return GetCursorFallbackModels()
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
// GetUsableModels is a unary RPC call (not streaming)
|
|
// Send an empty protobuf request
|
|
emptyReq := make([]byte, 0)
|
|
|
|
h2Req, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
|
cursorAPIURL+cursorModelsPath, bytes.NewReader(emptyReq))
|
|
if err != nil {
|
|
log.Debugf("cursor: failed to create models request: %v", err)
|
|
return GetCursorFallbackModels()
|
|
}
|
|
|
|
h2Req.Header.Set("Content-Type", "application/proto")
|
|
h2Req.Header.Set("Te", "trailers")
|
|
h2Req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
h2Req.Header.Set("X-Ghost-Mode", "true")
|
|
h2Req.Header.Set("X-Cursor-Client-Version", cursorClientVersion)
|
|
h2Req.Header.Set("X-Cursor-Client-Type", "cli")
|
|
|
|
client := newH2Client()
|
|
resp, err := client.Do(h2Req)
|
|
if err != nil {
|
|
log.Debugf("cursor: models request failed: %v", err)
|
|
return GetCursorFallbackModels()
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
log.Debugf("cursor: models request returned status %d", resp.StatusCode)
|
|
return GetCursorFallbackModels()
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return GetCursorFallbackModels()
|
|
}
|
|
|
|
models := parseModelsResponse(body)
|
|
if len(models) == 0 {
|
|
return GetCursorFallbackModels()
|
|
}
|
|
return models
|
|
}
|
|
|
|
// parseModelsResponse decodes a GetUsableModels response body into model
// infos. The body may arrive with or without a Connect envelope frame;
// after unframing it is walked as raw protobuf wire format, collecting
// each entry of the repeated `models` field (field number 1).
func parseModelsResponse(data []byte) []*registry.ModelInfo {
	// Try stripping Connect framing first
	if len(data) >= cursorproto.ConnectFrameHeaderSize {
		_, payload, _, ok := cursorproto.ParseConnectFrame(data)
		if ok {
			data = payload
		}
	}

	// The response is a GetUsableModelsResponse protobuf.
	// We need to decode it manually - it contains a repeated "models" field.
	// Based on the TS code, the response has a `models` field (repeated) containing
	// model objects with modelId, displayName, thinkingDetails, etc.

	// For now, we'll try a simple decode approach
	var models []*registry.ModelInfo
	// Field 1 is likely "models" (repeated submessage)
	for len(data) > 0 {
		num, typ, n := consumeTag(data)
		if n < 0 {
			// Malformed tag: stop and return whatever decoded so far.
			break
		}
		data = data[n:]

		if typ == 2 { // BytesType (submessage)
			val, n := consumeBytes(data)
			if n < 0 {
				break
			}
			data = data[n:]

			if num == 1 { // models field
				// Skip entries parseModelEntry rejects (no modelId).
				if m := parseModelEntry(val); m != nil {
					models = append(models, m)
				}
			}
		} else {
			// Unknown scalar field: skip its value and keep scanning.
			n := consumeFieldValue(num, typ, data)
			if n < 0 {
				break
			}
			data = data[n:]
		}
	}

	return models
}
|
|
|
|
// parseModelEntry decodes one protobuf model submessage into a ModelInfo.
// Field numbers follow the observed wire layout: 1=modelId,
// 2=thinkingDetails (mere presence implies thinking support),
// 3=displayModelId (fallback), 4=displayName, 5=displayNameShort
// (fallback). Returns nil on malformed input or when no modelId is found.
//
// NOTE(review): context/output limits are not present in the wire message;
// the 200000/64000/50000 values below are assumed defaults — confirm.
func parseModelEntry(data []byte) *registry.ModelInfo {
	var modelId, displayName string
	var hasThinking bool

	for len(data) > 0 {
		num, typ, n := consumeTag(data)
		if n < 0 {
			break
		}
		data = data[n:]

		switch typ {
		case 2: // BytesType
			val, n := consumeBytes(data)
			if n < 0 {
				return nil
			}
			data = data[n:]
			switch num {
			case 1: // modelId
				modelId = string(val)
			case 2: // thinkingDetails
				hasThinking = true
			case 3: // displayModelId (use as fallback)
				if displayName == "" {
					displayName = string(val)
				}
			case 4: // displayName
				displayName = string(val)
			case 5: // displayNameShort
				if displayName == "" {
					displayName = string(val)
				}
			}
		case 0: // VarintType
			_, n := consumeVarint(data)
			if n < 0 {
				return nil
			}
			data = data[n:]
		default:
			n := consumeFieldValue(num, typ, data)
			if n < 0 {
				return nil
			}
			data = data[n:]
		}
	}

	if modelId == "" {
		return nil
	}
	if displayName == "" {
		displayName = modelId
	}

	info := &registry.ModelInfo{
		ID:                  modelId,
		Object:              "model",
		Created:             time.Now().Unix(),
		OwnedBy:             "cursor",
		Type:                cursorAuthType,
		DisplayName:         displayName,
		ContextLength:       200000,
		MaxCompletionTokens: 64000,
	}
	if hasThinking {
		info.Thinking = &registry.ThinkingSupport{
			Max:            50000,
			DynamicAllowed: true,
		}
	}
	return info
}
|
|
|
|
// GetCursorFallbackModels returns hardcoded fallback models.
// Used when the GetUsableModels RPC cannot be made or yields nothing.
// NOTE(review): unlike parseModelEntry, these entries leave Created at its
// zero value — confirm whether the registry backfills it.
func GetCursorFallbackModels() []*registry.ModelInfo {
	return []*registry.ModelInfo{
		{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: cursorAuthType, DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &registry.ThinkingSupport{Max: 50000, DynamicAllowed: true}},
		{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: cursorAuthType, DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &registry.ThinkingSupport{Max: 50000, DynamicAllowed: true}},
		{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: cursorAuthType, DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
		{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: cursorAuthType, DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
		{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: cursorAuthType, DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
		{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: cursorAuthType, DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &registry.ThinkingSupport{Max: 50000, DynamicAllowed: true}},
	}
}
|
|
|
|
// Low-level protowire helpers (avoid importing protowire in executor)

// consumeTag reads one protobuf field tag, returning its field number,
// wire type, and the number of bytes consumed (-1 on malformed input).
func consumeTag(b []byte) (num int, typ int, n int) {
	tag, used := consumeVarint(b)
	if used < 0 {
		return 0, 0, -1
	}
	return int(tag >> 3), int(tag & 7), used
}

// consumeVarint decodes a base-128 varint (at most 10 bytes), returning
// the value and bytes consumed, or (0, -1) when no terminator appears.
func consumeVarint(b []byte) (uint64, int) {
	var value uint64
	var shift uint
	for i, c := range b {
		if i == 10 {
			break
		}
		value |= uint64(c&0x7f) << shift
		if c < 0x80 { // high bit clear terminates the varint
			return value, i + 1
		}
		shift += 7
	}
	return 0, -1
}

// consumeBytes reads a length-delimited field: a varint length followed by
// that many payload bytes. Returns the payload and total bytes consumed,
// or (nil, -1) when the header is bad or the payload is truncated.
func consumeBytes(b []byte) ([]byte, int) {
	size, hdr := consumeVarint(b)
	if hdr < 0 || int(size) > len(b)-hdr {
		return nil, -1
	}
	end := hdr + int(size)
	return b[hdr:end], end
}

// consumeFieldValue skips the value of a field with the given wire type,
// returning the number of bytes it occupies (-1 for unknown wire types or
// truncated input). num is unused but kept for call-site symmetry.
func consumeFieldValue(num, typ int, b []byte) int {
	switch typ {
	case 0: // Varint
		_, n := consumeVarint(b)
		return n
	case 1: // 64-bit
		if len(b) >= 8 {
			return 8
		}
		return -1
	case 2: // Length-delimited
		_, n := consumeBytes(b)
		return n
	case 5: // 32-bit
		if len(b) >= 4 {
			return 4
		}
		return -1
	}
	return -1
}
|