diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index d37cd2c2..4558b319 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -1,3 +1,6 @@ +// Package executor provides runtime execution capabilities for various AI service providers. +// This file implements the AI Studio executor that routes requests through a websocket-backed +// transport for the AI Studio provider. package executor import ( @@ -26,19 +29,28 @@ type AIStudioExecutor struct { cfg *config.Config } -// NewAIStudioExecutor constructs a websocket executor for the provider name. +// NewAIStudioExecutor creates a new AI Studio executor instance. +// +// Parameters: +// - cfg: The application configuration +// - provider: The provider name +// - relay: The websocket relay manager +// +// Returns: +// - *AIStudioExecutor: A new AI Studio executor instance func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor { return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg} } -// Identifier returns the logical provider key for routing. +// Identifier returns the executor identifier. func (e *AIStudioExecutor) Identifier() string { return "aistudio" } -// PrepareRequest is a no-op because websocket transport already injects headers. +// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio). func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } +// Execute performs a non-streaming request to the AI Studio API. func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -92,6 +104,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, return resp, nil } +// ExecuteStream performs a streaming request to the AI Studio API. func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -239,6 +252,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth return stream, nil } +// CountTokens counts tokens for the given request using the AI Studio API. func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { _, body, err := e.translateRequest(req, opts, false) if err != nil { @@ -293,8 +307,8 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } -func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - _ = ctx +// Refresh refreshes the authentication credentials (no-op for AI Studio). +func (e *AIStudioExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { return auth, nil } diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index e9ae3dc0..a77bc037 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1,3 +1,6 @@ +// Package executor provides runtime execution capabilities for various AI service providers. +// This file implements the Antigravity executor that proxies requests to the antigravity +// upstream using OAuth credentials. package executor import ( @@ -30,16 +33,15 @@ import ( const ( antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" // antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" - antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" - antigravityStreamPath = "/v1internal:streamGenerateContent" - antigravityGeneratePath = "/v1internal:generateContent" - antigravityModelsPath = "/v1internal:fetchAvailableModels" - antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" - antigravityAuthType = "antigravity" - refreshSkew = 3000 * time.Second - streamScannerBuffer int = 52_428_800 // 50MB + antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" + antigravityStreamPath = "/v1internal:streamGenerateContent" + antigravityGeneratePath = "/v1internal:generateContent" + antigravityModelsPath = "/v1internal:fetchAvailableModels" + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" + antigravityAuthType = "antigravity" + refreshSkew = 3000 * time.Second ) var ( @@ -52,18 +54,24 @@ type AntigravityExecutor struct { cfg *config.Config } -// NewAntigravityExecutor constructs a new executor instance. +// NewAntigravityExecutor creates a new Antigravity executor instance. +// +// Parameters: +// - cfg: The application configuration +// +// Returns: +// - *AntigravityExecutor: A new Antigravity executor instance func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor { return &AntigravityExecutor{cfg: cfg} } -// Identifier implements ProviderExecutor. +// Identifier returns the executor identifier. func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType } -// PrepareRequest implements ProviderExecutor. +// PrepareRequest prepares the HTTP request for execution (no-op for Antigravity). func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } -// Execute handles non-streaming requests via the antigravity generate endpoint. +// Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { @@ -156,7 +164,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au return resp, err } -// ExecuteStream handles streaming requests via the antigravity upstream. +// ExecuteStream performs a streaming request to the Antigravity API. func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { ctx = context.WithValue(ctx, "alt", "") @@ -296,7 +304,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya return nil, err } -// Refresh refreshes the OAuth token using the refresh token. +// Refresh refreshes the authentication credentials using the refresh token. func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { if auth == nil { return auth, nil @@ -308,7 +316,7 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au return updated, nil } -// CountTokens is not supported for the antigravity provider. +// CountTokens counts tokens for the given request (not supported for Antigravity). func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"} } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 2c4f3f88..4db33c57 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -1,3 +1,6 @@ +// Package executor provides runtime execution capabilities for various AI service providers. +// This file implements the Gemini CLI executor that talks to Cloud Code Assist endpoints +// using OAuth credentials from auth metadata. package executor import ( @@ -29,11 +32,11 @@ import ( const ( codeAssistEndpoint = "https://cloudcode-pa.googleapis.com" codeAssistVersion = "v1internal" - geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - geminiOauthClientSecret = "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" + geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + geminiOAuthClientSecret = "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" ) -var geminiOauthScopes = []string{ +var geminiOAuthScopes = []string{ "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile", @@ -44,14 +47,24 @@ type GeminiCLIExecutor struct { cfg *config.Config } +// NewGeminiCLIExecutor creates a new Gemini CLI executor instance. +// +// Parameters: +// - cfg: The application configuration +// +// Returns: +// - *GeminiCLIExecutor: A new Gemini CLI executor instance func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor { return &GeminiCLIExecutor{cfg: cfg} } +// Identifier returns the executor identifier. func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" } +// PrepareRequest prepares the HTTP request for execution (no-op for Gemini CLI). func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } +// Execute performs a non-streaming request to the Gemini CLI API. func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth) if err != nil { @@ -189,6 +202,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth return resp, err } +// ExecuteStream performs a streaming request to the Gemini CLI API. func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth) if err != nil { @@ -309,7 +323,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut }() if opts.Alt == "" { scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(nil, 52_428_800) // 50MB + scanner.Buffer(nil, streamScannerBuffer) var param any for scanner.Scan() { line := scanner.Bytes() @@ -371,6 +385,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut return nil, err } +// CountTokens counts tokens for the given request using the Gemini CLI API. func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth) if err != nil { @@ -471,9 +486,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody) } -func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - log.Debugf("gemini cli executor: refresh called") - _ = ctx +// Refresh refreshes the authentication credentials (no-op for Gemini CLI). +func (e *GeminiCLIExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { return auth, nil } @@ -515,9 +529,9 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth * } conf := &oauth2.Config{ - ClientID: geminiOauthClientID, - ClientSecret: geminiOauthClientSecret, - Scopes: geminiOauthScopes, + ClientID: geminiOAuthClientID, + ClientSecret: geminiOAuthClientSecret, + Scopes: geminiOAuthScopes, Endpoint: google.Endpoint, } diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 7b94b145..8dd3dc3b 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -11,7 +11,6 @@ import ( "io" "net/http" "strings" - "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" @@ -21,8 +20,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" ) const ( @@ -31,6 +28,9 @@ const ( // glAPIVersion is the API version used for Gemini requests. glAPIVersion = "v1beta" + + // streamScannerBuffer is the buffer size for SSE stream scanning. + streamScannerBuffer = 52_428_800 ) // GeminiExecutor is a stateless executor for the official Gemini API using API keys. @@ -48,9 +48,11 @@ type GeminiExecutor struct { // // Returns: // - *GeminiExecutor: A new Gemini executor instance -func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor { return &GeminiExecutor{cfg: cfg} } +func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor { + return &GeminiExecutor{cfg: cfg} +} -// Identifier returns the executor identifier for Gemini. +// Identifier returns the executor identifier. func (e *GeminiExecutor) Identifier() string { return "gemini" } // PrepareRequest prepares the HTTP request for execution (no-op for Gemini). @@ -164,6 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r return resp, nil } +// ExecuteStream performs a streaming request to the Gemini API. func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { apiKey, bearer := geminiCreds(auth) @@ -249,7 +252,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } }() scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 52_428_800) // 50MB + scanner.Buffer(nil, streamScannerBuffer) var param any for scanner.Scan() { line := scanner.Bytes() @@ -280,6 +283,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return stream, nil } +// CountTokens counts tokens for the given request using the Gemini API. func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { apiKey, bearer := geminiCreds(auth) @@ -353,106 +357,8 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } -func (e *GeminiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - log.Debugf("gemini executor: refresh called") - // OAuth bearer token refresh for official Gemini API. - if auth == nil { - return nil, fmt.Errorf("gemini executor: auth is nil") - } - if auth.Metadata == nil { - return auth, nil - } - // Token data is typically nested under "token" map in Gemini files. - tokenMap, _ := auth.Metadata["token"].(map[string]any) - var refreshToken, accessToken, clientID, clientSecret, tokenURI, expiryStr string - if tokenMap != nil { - if v, ok := tokenMap["refresh_token"].(string); ok { - refreshToken = v - } - if v, ok := tokenMap["access_token"].(string); ok { - accessToken = v - } - if v, ok := tokenMap["client_id"].(string); ok { - clientID = v - } - if v, ok := tokenMap["client_secret"].(string); ok { - clientSecret = v - } - if v, ok := tokenMap["token_uri"].(string); ok { - tokenURI = v - } - if v, ok := tokenMap["expiry"].(string); ok { - expiryStr = v - } - } else { - // Fallback to top-level keys if present - if v, ok := auth.Metadata["refresh_token"].(string); ok { - refreshToken = v - } - if v, ok := auth.Metadata["access_token"].(string); ok { - accessToken = v - } - if v, ok := auth.Metadata["client_id"].(string); ok { - clientID = v - } - if v, ok := auth.Metadata["client_secret"].(string); ok { - clientSecret = v - } - if v, ok := auth.Metadata["token_uri"].(string); ok { - tokenURI = v - } - if v, ok := auth.Metadata["expiry"].(string); ok { - expiryStr = v - } - } - if refreshToken == "" { - // Nothing to do for API key or cookie based entries - return auth, nil - } - - // Prepare oauth2 config; default to Google endpoints - endpoint := google.Endpoint - if tokenURI != "" { - endpoint.TokenURL = tokenURI - } - conf := &oauth2.Config{ClientID: clientID, ClientSecret: clientSecret, Endpoint: endpoint} - - // Ensure proxy-aware HTTP client for token refresh - httpClient := util.SetProxy(&e.cfg.SDKConfig, &http.Client{}) - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - // Build base token - tok := &oauth2.Token{AccessToken: accessToken, RefreshToken: refreshToken} - if t, err := time.Parse(time.RFC3339, expiryStr); err == nil { - tok.Expiry = t - } - newTok, err := conf.TokenSource(ctx, tok).Token() - if err != nil { - return nil, err - } - - // Persist back to metadata; prefer nested token map if present - if tokenMap == nil { - tokenMap = make(map[string]any) - } - tokenMap["access_token"] = newTok.AccessToken - tokenMap["refresh_token"] = newTok.RefreshToken - tokenMap["expiry"] = newTok.Expiry.Format(time.RFC3339) - if clientID != "" { - tokenMap["client_id"] = clientID - } - if clientSecret != "" { - tokenMap["client_secret"] = clientSecret - } - if tokenURI != "" { - tokenMap["token_uri"] = tokenURI - } - auth.Metadata["token"] = tokenMap - - // Also mirror top-level access_token for compatibility if previously present - if _, ok := auth.Metadata["access_token"]; ok { - auth.Metadata["access_token"] = newTok.AccessToken - } +// Refresh refreshes the authentication credentials (no-op for Gemini API key). +func (e *GeminiExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { return auth, nil } diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 51a6118c..df8ee506 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -1,6 +1,6 @@ -// Package executor contains provider executors. This file implements the Vertex AI -// Gemini executor that talks to Google Vertex AI endpoints using service account -// credentials imported by the CLI. +// Package executor provides runtime execution capabilities for various AI service providers. +// This file implements the Vertex AI Gemini executor that talks to Google Vertex AI +// endpoints using service account credentials or API keys. package executor import ( @@ -36,20 +36,26 @@ type GeminiVertexExecutor struct { cfg *config.Config } -// NewGeminiVertexExecutor constructs the Vertex executor. +// NewGeminiVertexExecutor creates a new Vertex AI Gemini executor instance. +// +// Parameters: +// - cfg: The application configuration +// +// Returns: +// - *GeminiVertexExecutor: A new Vertex AI Gemini executor instance func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor { return &GeminiVertexExecutor{cfg: cfg} } -// Identifier returns provider key for manager routing. +// Identifier returns the executor identifier. func (e *GeminiVertexExecutor) Identifier() string { return "vertex" } -// PrepareRequest is a no-op for Vertex. +// PrepareRequest prepares the HTTP request for execution (no-op for Vertex). func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } -// Execute handles non-streaming requests. +// Execute performs a non-streaming request to the Vertex AI API. func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { // Try API key authentication first apiKey, baseURL := vertexAPICreds(auth) @@ -67,7 +73,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) } -// ExecuteStream handles SSE streaming for Vertex. +// ExecuteStream performs a streaming request to the Vertex AI API. func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { // Try API key authentication first apiKey, baseURL := vertexAPICreds(auth) @@ -85,7 +91,7 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) } -// CountTokens calls Vertex countTokens endpoint. +// CountTokens counts tokens for the given request using the Vertex AI API. func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { // Try API key authentication first apiKey, baseURL := vertexAPICreds(auth) @@ -103,185 +109,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) } -// countTokensWithServiceAccount handles token counting using service account credentials. -func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - - from := opts.SourceFormat - to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { - if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) - budgetOverride = &norm - } - translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) - } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") - - baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") - - httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) - if errNewReq != nil { - return cliproxyexecutor.Response{}, errNewReq - } - httpReq.Header.Set("Content-Type", "application/json") - if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } else if errTok != nil { - log.Errorf("vertex executor: access token error: %v", errTok) - return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} - } - applyGeminiHeaders(httpReq, auth) - - var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: translatedReq, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { - recordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo - } - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("vertex executor: close response body error: %v", errClose) - } - }() - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - data, errRead := io.ReadAll(httpResp.Body) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} - } - count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil -} - -// countTokensWithAPIKey handles token counting using API key credentials. -func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - - from := opts.SourceFormat - to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { - if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) - budgetOverride = &norm - } - translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) - } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") - - // For API key auth, use simpler URL format without project/location - if baseURL == "" { - baseURL = "https://generativelanguage.googleapis.com" - } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") - - httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) - if errNewReq != nil { - return cliproxyexecutor.Response{}, errNewReq - } - httpReq.Header.Set("Content-Type", "application/json") - if apiKey != "" { - httpReq.Header.Set("x-goog-api-key", apiKey) - } - applyGeminiHeaders(httpReq, auth) - - var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: translatedReq, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { - recordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo - } - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("vertex executor: close response body error: %v", errClose) - } - }() - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - data, errRead := io.ReadAll(httpResp.Body) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} - } - count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil -} - -// Refresh is a no-op for service account based credentials. +// Refresh refreshes the authentication credentials (no-op for Vertex). func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { return auth, nil } @@ -579,7 +407,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } }() scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 52_428_800) // 50MB + scanner.Buffer(nil, streamScannerBuffer) var param any for scanner.Scan() { line := scanner.Bytes() @@ -696,7 +524,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } }() scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 52_428_800) // 50MB + scanner.Buffer(nil, streamScannerBuffer) var param any for scanner.Scan() { line := scanner.Bytes() @@ -722,6 +550,184 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth return stream, nil } +// countTokensWithServiceAccount counts tokens using service account credentials. +func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + baseURL := vertexBaseURL(location) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { + httpReq.Header.Set("Authorization", "Bearer "+token) + } else if errTok != nil { + log.Errorf("vertex executor: access token error: %v", errTok) + return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// countTokensWithAPIKey handles token counting using API key credentials. +func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + // vertexCreds extracts project, location and raw service account JSON from auth metadata. func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) { if a == nil || a.Metadata == nil {