Merge branch 'router-for-me:main' into main

This commit is contained in:
Luis Pater
2025-12-12 02:41:05 +08:00
committed by GitHub
5 changed files with 276 additions and 328 deletions

View File

@@ -1,3 +1,6 @@
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the AI Studio executor that routes requests through a websocket-backed
// transport for the AI Studio provider.
package executor
import (
@@ -26,19 +29,28 @@ type AIStudioExecutor struct {
cfg *config.Config
}
// NewAIStudioExecutor constructs a websocket executor for the provider name.
// NewAIStudioExecutor creates a new AI Studio executor instance.
//
// Parameters:
// - cfg: The application configuration
// - provider: The provider name
// - relay: The websocket relay manager
//
// Returns:
// - *AIStudioExecutor: A new AI Studio executor instance
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
}
// Identifier returns the logical provider key for routing.
// Identifier returns the executor identifier.
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
// PrepareRequest is a no-op because websocket transport already injects headers.
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio).
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
return nil
}
// Execute performs a non-streaming request to the AI Studio API.
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -92,6 +104,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, nil
}
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -239,6 +252,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return stream, nil
}
// CountTokens counts tokens for the given request using the AI Studio API.
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
@@ -293,8 +307,8 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
_ = ctx
// Refresh refreshes the authentication credentials (no-op for AI Studio).
func (e *AIStudioExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
}

View File

@@ -1,3 +1,6 @@
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the Antigravity executor that proxies requests to the antigravity
// upstream using OAuth credentials.
package executor
import (
@@ -30,16 +33,15 @@ import (
const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
streamScannerBuffer int = 52_428_800 // 50MB
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
)
var (
@@ -52,18 +54,24 @@ type AntigravityExecutor struct {
cfg *config.Config
}
// NewAntigravityExecutor constructs a new executor instance.
// NewAntigravityExecutor creates a new Antigravity executor instance.
//
// Parameters:
// - cfg: The application configuration
//
// Returns:
// - *AntigravityExecutor: A new Antigravity executor instance
func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
return &AntigravityExecutor{cfg: cfg}
}
// Identifier implements ProviderExecutor.
// Identifier returns the executor identifier.
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
// PrepareRequest implements ProviderExecutor.
// PrepareRequest prepares the HTTP request for execution (no-op for Antigravity).
func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// Execute handles non-streaming requests via the antigravity generate endpoint.
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
if errToken != nil {
@@ -156,7 +164,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
return resp, err
}
// ExecuteStream handles streaming requests via the antigravity upstream.
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
ctx = context.WithValue(ctx, "alt", "")
@@ -296,7 +304,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
return nil, err
}
// Refresh refreshes the OAuth token using the refresh token.
// Refresh refreshes the authentication credentials using the refresh token.
func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return auth, nil
@@ -308,7 +316,7 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
return updated, nil
}
// CountTokens is not supported for the antigravity provider.
// CountTokens counts tokens for the given request (not supported for Antigravity).
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
}

View File

@@ -1,3 +1,6 @@
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the Gemini CLI executor that talks to Cloud Code Assist endpoints
// using OAuth credentials from auth metadata.
package executor
import (
@@ -29,11 +32,11 @@ import (
const (
codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
codeAssistVersion = "v1internal"
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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"
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOauthScopes = []string{
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
@@ -44,14 +47,24 @@ type GeminiCLIExecutor struct {
cfg *config.Config
}
// NewGeminiCLIExecutor creates a new Gemini CLI executor instance.
//
// Parameters:
// - cfg: The application configuration
//
// Returns:
// - *GeminiCLIExecutor: A new Gemini CLI executor instance
func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
return &GeminiCLIExecutor{cfg: cfg}
}
// Identifier returns the executor identifier.
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini CLI).
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// Execute performs a non-streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
@@ -189,6 +202,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err
}
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
@@ -309,7 +323,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}()
if opts.Alt == "" {
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
@@ -371,6 +385,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
return nil, err
}
// CountTokens counts tokens for the given request using the Gemini CLI API.
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
@@ -471,9 +486,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody)
}
func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
log.Debugf("gemini cli executor: refresh called")
_ = ctx
// Refresh refreshes the authentication credentials (no-op for Gemini CLI).
func (e *GeminiCLIExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
}
@@ -515,9 +529,9 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
}
conf := &oauth2.Config{
ClientID: geminiOauthClientID,
ClientSecret: geminiOauthClientSecret,
Scopes: geminiOauthScopes,
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}

View File

@@ -11,7 +11,6 @@ import (
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -21,8 +20,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const (
@@ -31,6 +28,9 @@ const (
// glAPIVersion is the API version used for Gemini requests.
glAPIVersion = "v1beta"
// streamScannerBuffer is the buffer size for SSE stream scanning.
streamScannerBuffer = 52_428_800
)
// GeminiExecutor is a stateless executor for the official Gemini API using API keys.
@@ -48,9 +48,11 @@ type GeminiExecutor struct {
//
// Returns:
// - *GeminiExecutor: A new Gemini executor instance
func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor { return &GeminiExecutor{cfg: cfg} }
func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor {
return &GeminiExecutor{cfg: cfg}
}
// Identifier returns the executor identifier for Gemini.
// Identifier returns the executor identifier.
func (e *GeminiExecutor) Identifier() string { return "gemini" }
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
@@ -164,6 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
return resp, nil
}
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
apiKey, bearer := geminiCreds(auth)
@@ -249,7 +252,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
@@ -280,6 +283,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return stream, nil
}
// CountTokens counts tokens for the given request using the Gemini API.
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth)
@@ -353,106 +357,8 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *GeminiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
log.Debugf("gemini executor: refresh called")
// OAuth bearer token refresh for official Gemini API.
if auth == nil {
return nil, fmt.Errorf("gemini executor: auth is nil")
}
if auth.Metadata == nil {
return auth, nil
}
// Token data is typically nested under "token" map in Gemini files.
tokenMap, _ := auth.Metadata["token"].(map[string]any)
var refreshToken, accessToken, clientID, clientSecret, tokenURI, expiryStr string
if tokenMap != nil {
if v, ok := tokenMap["refresh_token"].(string); ok {
refreshToken = v
}
if v, ok := tokenMap["access_token"].(string); ok {
accessToken = v
}
if v, ok := tokenMap["client_id"].(string); ok {
clientID = v
}
if v, ok := tokenMap["client_secret"].(string); ok {
clientSecret = v
}
if v, ok := tokenMap["token_uri"].(string); ok {
tokenURI = v
}
if v, ok := tokenMap["expiry"].(string); ok {
expiryStr = v
}
} else {
// Fallback to top-level keys if present
if v, ok := auth.Metadata["refresh_token"].(string); ok {
refreshToken = v
}
if v, ok := auth.Metadata["access_token"].(string); ok {
accessToken = v
}
if v, ok := auth.Metadata["client_id"].(string); ok {
clientID = v
}
if v, ok := auth.Metadata["client_secret"].(string); ok {
clientSecret = v
}
if v, ok := auth.Metadata["token_uri"].(string); ok {
tokenURI = v
}
if v, ok := auth.Metadata["expiry"].(string); ok {
expiryStr = v
}
}
if refreshToken == "" {
// Nothing to do for API key or cookie based entries
return auth, nil
}
// Prepare oauth2 config; default to Google endpoints
endpoint := google.Endpoint
if tokenURI != "" {
endpoint.TokenURL = tokenURI
}
conf := &oauth2.Config{ClientID: clientID, ClientSecret: clientSecret, Endpoint: endpoint}
// Ensure proxy-aware HTTP client for token refresh
httpClient := util.SetProxy(&e.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// Build base token
tok := &oauth2.Token{AccessToken: accessToken, RefreshToken: refreshToken}
if t, err := time.Parse(time.RFC3339, expiryStr); err == nil {
tok.Expiry = t
}
newTok, err := conf.TokenSource(ctx, tok).Token()
if err != nil {
return nil, err
}
// Persist back to metadata; prefer nested token map if present
if tokenMap == nil {
tokenMap = make(map[string]any)
}
tokenMap["access_token"] = newTok.AccessToken
tokenMap["refresh_token"] = newTok.RefreshToken
tokenMap["expiry"] = newTok.Expiry.Format(time.RFC3339)
if clientID != "" {
tokenMap["client_id"] = clientID
}
if clientSecret != "" {
tokenMap["client_secret"] = clientSecret
}
if tokenURI != "" {
tokenMap["token_uri"] = tokenURI
}
auth.Metadata["token"] = tokenMap
// Also mirror top-level access_token for compatibility if previously present
if _, ok := auth.Metadata["access_token"]; ok {
auth.Metadata["access_token"] = newTok.AccessToken
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
func (e *GeminiExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
}

View File

@@ -1,6 +1,6 @@
// Package executor contains provider executors. This file implements the Vertex AI
// Gemini executor that talks to Google Vertex AI endpoints using service account
// credentials imported by the CLI.
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the Vertex AI Gemini executor that talks to Google Vertex AI
// endpoints using service account credentials or API keys.
package executor
import (
@@ -36,20 +36,26 @@ type GeminiVertexExecutor struct {
cfg *config.Config
}
// NewGeminiVertexExecutor constructs the Vertex executor.
// NewGeminiVertexExecutor creates a new Vertex AI Gemini executor instance.
//
// Parameters:
// - cfg: The application configuration
//
// Returns:
// - *GeminiVertexExecutor: A new Vertex AI Gemini executor instance
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
return &GeminiVertexExecutor{cfg: cfg}
}
// Identifier returns provider key for manager routing.
// Identifier returns the executor identifier.
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
// PrepareRequest is a no-op for Vertex.
// PrepareRequest prepares the HTTP request for execution (no-op for Vertex).
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
return nil
}
// Execute handles non-streaming requests.
// Execute performs a non-streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
// Try API key authentication first
apiKey, baseURL := vertexAPICreds(auth)
@@ -67,7 +73,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
}
// ExecuteStream handles SSE streaming for Vertex.
// ExecuteStream performs a streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
// Try API key authentication first
apiKey, baseURL := vertexAPICreds(auth)
@@ -85,7 +91,7 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
}
// CountTokens calls Vertex countTokens endpoint.
// CountTokens counts tokens for the given request using the Vertex AI API.
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
// Try API key authentication first
apiKey, baseURL := vertexAPICreds(auth)
@@ -103,185 +109,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
}
// countTokensWithServiceAccount handles token counting using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// Refresh is a no-op for service account based credentials.
// Refresh refreshes the authentication credentials (no-op for Vertex).
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
}
@@ -579,7 +407,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
@@ -696,7 +524,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
@@ -722,6 +550,184 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
return stream, nil
}
// countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
if a == nil || a.Metadata == nil {