mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
Merge remote-tracking branch 'upstream/dev' into dev
# Conflicts: # internal/runtime/executor/antigravity_executor.go
This commit is contained in:
@@ -58,6 +58,7 @@ func main() {
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var codexDeviceLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var iflowLogin bool
|
||||
@@ -76,6 +77,7 @@ func main() {
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
@@ -467,6 +469,9 @@ func main() {
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
} else if codexDeviceLogin {
|
||||
// Handle Codex device-code login
|
||||
cmd.DoCodexDeviceLogin(cfg, options)
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
|
||||
@@ -408,6 +408,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if !auth.LastRefreshedAt.IsZero() {
|
||||
entry["last_refresh"] = auth.LastRefreshedAt
|
||||
}
|
||||
if !auth.NextRetryAfter.IsZero() {
|
||||
entry["next_retry_after"] = auth.NextRetryAfter
|
||||
}
|
||||
if path != "" {
|
||||
entry["path"] = path
|
||||
entry["source"] = "file"
|
||||
@@ -947,11 +950,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
if h.postAuthHook != nil {
|
||||
if err := h.postAuthHook(ctx, record); err != nil {
|
||||
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||
}
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Claude authentication...")
|
||||
|
||||
@@ -1096,6 +1105,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||
|
||||
@@ -1354,6 +1364,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Codex authentication...")
|
||||
|
||||
@@ -1499,6 +1510,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Antigravity authentication...")
|
||||
|
||||
@@ -1663,6 +1675,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Qwen authentication...")
|
||||
|
||||
@@ -1718,6 +1731,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
@@ -1794,6 +1808,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing iFlow authentication...")
|
||||
|
||||
@@ -2412,3 +2427,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
// PopulateAuthContext extracts request info and adds it to the context
|
||||
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||
info := &coreauth.RequestInfo{
|
||||
Query: c.Request.URL.Query(),
|
||||
Headers: c.Request.Header,
|
||||
}
|
||||
return coreauth.WithRequestInfo(ctx, info)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type Handler struct {
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
h.postAuthHook = hook
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
|
||||
@@ -51,6 +51,7 @@ type serverOptionConfig struct {
|
||||
keepAliveEnabled bool
|
||||
keepAliveTimeout time.Duration
|
||||
keepAliveOnTimeout func()
|
||||
postAuthHook auth.PostAuthHook
|
||||
}
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||
return func(cfg *serverOptionConfig) {
|
||||
cfg.postAuthHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
if optionState.postAuthHook != nil {
|
||||
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||
}
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
// Setup routes
|
||||
|
||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||
// login while preserving the existing token parsing and storage behavior.
|
||||
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
if isNonRetryableRefreshErr(err) {
|
||||
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableRefreshErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
raw := strings.ToLower(err.Error())
|
||||
return strings.Contains(raw, "refresh_token_reused")
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||
// This is typically called after a successful token refresh to persist the new credentials.
|
||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||
|
||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||
var calls int32
|
||||
auth := &CodexAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for non-retryable refresh failure")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
)
|
||||
|
||||
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||
// existing codex-login OAuth callback flow intact.
|
||||
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("Codex device authentication successful!")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
||||
func LogCredentialSeparator() {
|
||||
log.Debug(credentialSeparator)
|
||||
}
|
||||
|
||||
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||
var data map[string]any
|
||||
|
||||
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||
if srcMap, ok := source.(map[string]any); ok {
|
||||
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||
for k, v := range srcMap {
|
||||
data[k] = v
|
||||
}
|
||||
} else {
|
||||
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||
temp, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(temp, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra metadata
|
||||
if metadata != nil {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
for k, v := range metadata {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -904,19 +904,12 @@ func GetIFlowModels() []*ModelInfo {
|
||||
Created int64
|
||||
Thinking *ThinkingSupport
|
||||
}{
|
||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
@@ -925,11 +918,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -963,6 +952,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
|
||||
@@ -55,8 +55,78 @@ const (
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||
antigravityPrimaryModelsCache struct {
|
||||
mu sync.RWMutex
|
||||
models []*registry.ModelInfo
|
||||
}
|
||||
)
|
||||
|
||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*registry.ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneAntigravityModelInfo(model))
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *model
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
thinkingClone := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
clone.Thinking = &thinkingClone
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||
cloned := cloneAntigravityModels(models)
|
||||
if len(cloned) == 0 {
|
||||
return false
|
||||
}
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = cloned
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
antigravityPrimaryModelsCache.mu.RLock()
|
||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||
return cloned
|
||||
}
|
||||
|
||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
models := loadAntigravityPrimaryModels()
|
||||
if len(models) > 0 {
|
||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -1072,7 +1142,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
@@ -1096,7 +1166,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
@@ -1109,14 +1179,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Errorf("antigravity executor: models request failed: %v", errDo)
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
@@ -1128,21 +1197,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Errorf("antigravity executor: models read body failed: %v", errRead)
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Errorf("antigravity executor: models request error status %d: %s", httpResp.StatusCode, string(bodyBytes))
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -1210,9 +1285,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
storeAntigravityPrimaryModels(models)
|
||||
return models
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = nil
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
seed := []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||
t.Fatal("expected non-empty model list to update primary cache")
|
||||
}
|
||||
|
||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||
ID: "gpt-5",
|
||||
DisplayName: "GPT-5",
|
||||
SupportedGenerationMethods: []string{"generateContent"},
|
||||
SupportedParameters: []string{"temperature"},
|
||||
Thinking: ®istry.ThinkingSupport{
|
||||
Levels: []string{"high"},
|
||||
},
|
||||
}}); !updated {
|
||||
t.Fatal("expected model cache update")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one cached model, got %d", len(got))
|
||||
}
|
||||
got[0].ID = "mutated-id"
|
||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||
}
|
||||
if len(got[0].SupportedParameters) > 0 {
|
||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||
}
|
||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||
got[0].Thinking.Levels[0] = "mutated-level"
|
||||
}
|
||||
|
||||
again := loadAntigravityPrimaryModels()
|
||||
if len(again) != 1 {
|
||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||
}
|
||||
if again[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||
}
|
||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||
}
|
||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||
}
|
||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||
}
|
||||
}
|
||||
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -358,7 +358,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
@@ -673,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
|
||||
return nil
|
||||
}
|
||||
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
|
||||
resetAtTime := time.Unix(resetsAt, 0)
|
||||
if resetAtTime.After(now) {
|
||||
retryAfter := resetAtTime.Sub(now)
|
||||
return &retryAfter
|
||||
}
|
||||
}
|
||||
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
|
||||
retryAfter := time.Duration(resetsInSeconds) * time.Second
|
||||
return &retryAfter
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseCodexRetryAfter(t *testing.T) {
|
||||
now := time.Unix(1_700_000_000, 0)
|
||||
|
||||
t.Run("resets_in_seconds", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 123*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers resets_at", func(t *testing.T) {
|
||||
resetAt := now.Add(5 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 5*time.Minute {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback when resets_at is past", func(t *testing.T) {
|
||||
resetAt := now.Add(-1 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 77*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-429 status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-429, got %v", *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non usage_limit_reached error type", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
@@ -22,9 +23,151 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||
)
|
||||
|
||||
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||
var qwenBeijingLoc = func() *time.Location {
|
||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||
if err != nil || loc == nil {
|
||||
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||
return time.FixedZone("CST", 8*3600)
|
||||
}
|
||||
return loc
|
||||
}()
|
||||
|
||||
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||
var qwenQuotaCodes = map[string]struct{}{
|
||||
"insufficient_quota": {},
|
||||
"quota_exceeded": {},
|
||||
}
|
||||
|
||||
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||
// Qwen has a limit of 60 requests per minute per account.
|
||||
var qwenRateLimiter = struct {
|
||||
sync.Mutex
|
||||
requests map[string][]time.Time // authID -> request timestamps
|
||||
}{
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||
// Keeps a small prefix/suffix to allow correlation across events.
|
||||
func redactAuthID(id string) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
if len(id) <= 8 {
|
||||
return id
|
||||
}
|
||||
return id[:4] + "..." + id[len(id)-4:]
|
||||
}
|
||||
|
||||
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||
func checkQwenRateLimit(authID string) error {
|
||||
if authID == "" {
|
||||
// Empty authID should not bypass rate limiting in production
|
||||
// Use debug level to avoid log spam for certain auth flows
|
||||
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-qwenRateLimitWindow)
|
||||
|
||||
qwenRateLimiter.Lock()
|
||||
defer qwenRateLimiter.Unlock()
|
||||
|
||||
// Get and filter timestamps within the window
|
||||
timestamps := qwenRateLimiter.requests[authID]
|
||||
var validTimestamps []time.Time
|
||||
for _, ts := range timestamps {
|
||||
if ts.After(windowStart) {
|
||||
validTimestamps = append(validTimestamps, ts)
|
||||
}
|
||||
}
|
||||
|
||||
// Always prune expired entries to prevent memory leak
|
||||
// Delete empty entries, otherwise update with pruned slice
|
||||
if len(validTimestamps) == 0 {
|
||||
delete(qwenRateLimiter.requests, authID)
|
||||
}
|
||||
|
||||
// Check if rate limit exceeded
|
||||
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||
// Calculate when the oldest request will expire
|
||||
oldestInWindow := validTimestamps[0]
|
||||
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||
if retryAfter < time.Second {
|
||||
retryAfter = time.Second
|
||||
}
|
||||
retryAfterSec := int(retryAfter.Seconds())
|
||||
return statusErr{
|
||||
code: http.StatusTooManyRequests,
|
||||
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||
retryAfter: &retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
// Record this request and update the map with pruned timestamps
|
||||
validTimestamps = append(validTimestamps, now)
|
||||
qwenRateLimiter.requests[authID] = validTimestamps
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||
func isQwenQuotaError(body []byte) bool {
|
||||
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||
|
||||
// Primary check: exact match on error.code or error.type (most reliable)
|
||||
if _, ok := qwenQuotaCodes[code]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback: check message only if code/type don't match (less reliable)
|
||||
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||
strings.Contains(msg, "free allocated quota exceeded") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||
errCode = httpCode
|
||||
// Only check quota errors for expected status codes to avoid false positives
|
||||
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||
cooldown := timeUntilNextDay()
|
||||
retryAfter = &cooldown
|
||||
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||
}
|
||||
return errCode, retryAfter
|
||||
}
|
||||
|
||||
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||
func timeUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nowLocal := now.In(qwenBeijingLoc)
|
||||
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||
return tomorrow.Sub(now)
|
||||
}
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||
type QwenExecutor struct {
|
||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if opts.Alt == "responses/compact" {
|
||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
// Check rate limit before proceeding
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, false)
|
||||
var authID, authLabel, authType, authValue string
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
// Check rate limit before proceeding
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, true)
|
||||
var authID, authLabel, authType, authValue string
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
@@ -10,53 +10,10 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// validReasoningEffortLevels contains the standard values accepted by the
|
||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
||||
// auto) are NOT in this set and must be clamped before use.
|
||||
var validReasoningEffortLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
}
|
||||
|
||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
||||
// mapped to the nearest standard equivalent.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - none / low / medium / high → returned as-is (already valid)
|
||||
// - xhigh → "high" (nearest lower standard level)
|
||||
// - minimal → "low" (nearest higher standard level)
|
||||
// - auto → "medium" (reasonable default)
|
||||
// - anything else → "medium" (safe default)
|
||||
func clampReasoningEffort(level string) string {
|
||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
||||
return level
|
||||
}
|
||||
var clamped string
|
||||
switch level {
|
||||
case string(thinking.LevelXHigh):
|
||||
clamped = string(thinking.LevelHigh)
|
||||
case string(thinking.LevelMinimal):
|
||||
clamped = string(thinking.LevelLow)
|
||||
case string(thinking.LevelAuto):
|
||||
clamped = string(thinking.LevelMedium)
|
||||
default:
|
||||
clamped = string(thinking.LevelMedium)
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"original": level,
|
||||
"clamped": clamped,
|
||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
||||
return clamped
|
||||
}
|
||||
|
||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||
//
|
||||
// OpenAI-specific behavior:
|
||||
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
|
||||
if config.Mode == thinking.ModeLevel {
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
nonImageCount := 0
|
||||
lastNonImageRaw := ""
|
||||
filteredJSON := "[]"
|
||||
imagePartsJSON := "[]"
|
||||
for _, fr := range frResults {
|
||||
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := fr.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
continue
|
||||
}
|
||||
|
||||
nonImageCount++
|
||||
lastNonImageRaw = fr.Raw
|
||||
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||
}
|
||||
|
||||
if nonImageCount == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||
} else if nonImageCount > 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
// Place image data inside functionResponse.parts as inlineData
|
||||
// instead of as sibling parts in the outer content, to avoid
|
||||
// base64 data bloating the text context.
|
||||
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := "[]"
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
|
||||
@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
||||
if !inlineData.Exists() {
|
||||
t.Error("inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mime_type").String() != "image/png" {
|
||||
t.Error("mime_type mismatch")
|
||||
if inlineData.Get("mimeType").String() != "image/png" {
|
||||
t.Error("mimeType mismatch")
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
@@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
|
||||
// tool_result with array content containing text + image should place
|
||||
// image data inside functionResponse.parts as inlineData, not as a
|
||||
// sibling part in the outer content (to avoid base64 context bloat).
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-123-456",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "File content here"
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
// Image should be inside functionResponse.parts, not as outer sibling part
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Text content should be in response.result
|
||||
resultText := funcResp.Get("response.result.text").String()
|
||||
if resultText != "File content here" {
|
||||
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
|
||||
}
|
||||
|
||||
// Image should be in functionResponse.parts[0].inlineData
|
||||
inlineData := funcResp.Get("parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
}
|
||||
|
||||
// Image should NOT be in outer parts (only functionResponse part should exist)
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
|
||||
// tool_result with single image object as content should place
|
||||
// image data inside functionResponse.parts, not as outer sibling part.
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-789-012",
|
||||
"content": {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "/9j/4AAQSkZJRgABAQ=="
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// response.result should be empty (image only)
|
||||
if funcResp.Get("response.result").String() != "" {
|
||||
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
|
||||
// Image should be in functionResponse.parts[0].inlineData
|
||||
inlineData := funcResp.Get("parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mimeType").String() != "image/jpeg" {
|
||||
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
|
||||
}
|
||||
|
||||
// Image should NOT be in outer parts
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
|
||||
// tool_result with array content: 2 text items + 2 images
|
||||
// All images go into functionResponse.parts, texts into response.result array
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Multi-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "First text"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
|
||||
},
|
||||
{"type": "text", "text": "Second text"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Multiple text items => response.result is an array
|
||||
resultArr := funcResp.Get("response.result")
|
||||
if !resultArr.IsArray() {
|
||||
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
|
||||
}
|
||||
results := resultArr.Array()
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||
}
|
||||
|
||||
// Both images should be in functionResponse.parts
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 2 {
|
||||
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
|
||||
}
|
||||
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
|
||||
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
|
||||
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
|
||||
}
|
||||
|
||||
// Only 1 outer part (the functionResponse itself)
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(outerParts) != 1 {
|
||||
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
|
||||
// tool_result with only images (no text) — response.result should be empty string
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "ImgOnly-001",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// No text => response.result should be empty string
|
||||
if funcResp.Get("response.result").String() != "" {
|
||||
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
|
||||
// Both images in functionResponse.parts
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 2 {
|
||||
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Error("first image mimeType mismatch")
|
||||
}
|
||||
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
|
||||
t.Error("second image mimeType mismatch")
|
||||
}
|
||||
|
||||
// Only 1 outer part
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(outerParts) != 1 {
|
||||
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
|
||||
// image with source.type != "base64" should be treated as non-image (falls through)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NotB64-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "some output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "url", "url": "https://example.com/img.png"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Non-base64 image is treated as non-image, so it goes into the filtered results
|
||||
// along with the text item. Since there are 2 non-image items, result is array.
|
||||
resultArr := funcResp.Get("response.result")
|
||||
if !resultArr.IsArray() {
|
||||
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
|
||||
}
|
||||
results := resultArr.Array()
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||
}
|
||||
|
||||
// No functionResponse.parts (no base64 images collected)
|
||||
if funcResp.Get("parts").Exists() {
|
||||
t.Error("functionResponse.parts should NOT exist when no base64 images")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
|
||||
// image with source.type=base64 but missing data field
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NoData-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// The image is still classified as base64 image (type check passes),
|
||||
// but data field is missing => inlineData has mimeType but no data
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 1 {
|
||||
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Error("mimeType should still be set")
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").Exists() {
|
||||
t.Error("data should not exist when source.data is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
|
||||
// image with source.type=base64 but missing media_type field
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NoMime-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "data": "AAAA"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// The image is still classified as base64 image,
|
||||
// but media_type is missing => inlineData has data but no mimeType
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 1 {
|
||||
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").Exists() {
|
||||
t.Error("mimeType should not exist when media_type is missing")
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||
t.Error("data should still be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||
// When functionResponse contains a "parts" field with inlineData (from Claude
|
||||
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
|
||||
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
|
||||
// so extra fields like "parts" survive the pipeline.
|
||||
input := `{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"functionCall": {"name": "screenshot", "args": {}}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"id": "tool-001",
|
||||
"name": "screenshot",
|
||||
"response": {"result": "Screenshot taken"},
|
||||
"parts": [
|
||||
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
// Find the function response content (role=function)
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
// The functionResponse should be preserved with its parts field
|
||||
funcResp := funcContent.Get("parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist in output")
|
||||
}
|
||||
|
||||
// Verify the parts field with inlineData is preserved
|
||||
inlineParts := funcResp.Get("parts").Array()
|
||||
if len(inlineParts) != 1 {
|
||||
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
|
||||
}
|
||||
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
|
||||
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
|
||||
}
|
||||
|
||||
// Verify response.result is also preserved
|
||||
if funcResp.Get("response.result").String() != "Screenshot taken" {
|
||||
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
ext = sp[len(sp)-1]
|
||||
}
|
||||
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
||||
p++
|
||||
} else {
|
||||
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
|
||||
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
var textAggregate strings.Builder
|
||||
var partsJSON []string
|
||||
hasImage := false
|
||||
hasFile := false
|
||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
ptype := part.Get("type").String()
|
||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
hasImage = true
|
||||
}
|
||||
}
|
||||
case "input_file":
|
||||
fileData := part.Get("file_data").String()
|
||||
if fileData != "" {
|
||||
mediaType := "application/octet-stream"
|
||||
data := fileData
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||
if len(mediaAndData) == 2 {
|
||||
if mediaAndData[0] != "" {
|
||||
mediaType = mediaAndData[0]
|
||||
}
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
}
|
||||
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
hasFile = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if len(partsJSON) > 0 {
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
if len(partsJSON) == 1 && !hasImage {
|
||||
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||
// Preserve legacy behavior for single text content
|
||||
msg, _ = sjson.Delete(msg, "content")
|
||||
textPart := gjson.Parse(partsJSON[0])
|
||||
|
||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
case "file":
|
||||
// Files are not specified in examples; skip for now
|
||||
if role == "user" {
|
||||
fileData := it.Get("file.file_data").String()
|
||||
filename := it.Get("file.filename").String()
|
||||
if fileData != "" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_file")
|
||||
part, _ = sjson.Set(part, "file_data", fileData)
|
||||
if filename != "" {
|
||||
part, _ = sjson.Set(part, "filename", filename)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||
|
||||
// Delete the user field as it is not supported by the Codex upstream.
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||
// for Codex upstream compatibility.
|
||||
//
|
||||
// Codex /responses currently rejects context_management with:
|
||||
// {"detail":"Unsupported parameter: context_management"}.
|
||||
//
|
||||
// Compatibility strategy:
|
||||
// 1) Remove context_management before forwarding to Codex upstream.
|
||||
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||
// accept "system" role in the input array.
|
||||
|
||||
@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"context_management": [
|
||||
{
|
||||
"type": "compaction",
|
||||
"compact_threshold": 12000
|
||||
}
|
||||
],
|
||||
"input": [{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if gjson.Get(outputStr, "context_management").Exists() {
|
||||
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||
}
|
||||
if gjson.Get(outputStr, "truncation").Exists() {
|
||||
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"truncation": "disabled",
|
||||
"input": [{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if gjson.Get(outputStr, "truncation").Exists() {
|
||||
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
|
||||
// usage mapping
|
||||
if um := root.Get("usageMetadata"); um.Exists() {
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
// input tokens = prompt only (thoughts go to output)
|
||||
input := um.Get("promptTokenCount").Int()
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
|
||||
// usage mapping
|
||||
if um := root.Get("usageMetadata"); um.Exists() {
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
// input tokens = prompt only (thoughts go to output)
|
||||
input := um.Get("promptTokenCount").Int()
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
|
||||
@@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
if handlerType == "openai-response" {
|
||||
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
sentPayload = true
|
||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||
return
|
||||
@@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return dataChan, upstreamHeaders, errChan
|
||||
}
|
||||
|
||||
func validateSSEDataJSON(chunk []byte) error {
|
||||
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(data) {
|
||||
continue
|
||||
}
|
||||
const max = 512
|
||||
preview := data
|
||||
if len(preview) > max {
|
||||
preview = preview[:max]
|
||||
}
|
||||
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
|
||||
@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
type invalidJSONStreamExecutor struct{}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||
close(ch)
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||
executor := &invalidJSONStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
|
||||
gotErr := false
|
||||
for msg := range errChan {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||
}
|
||||
if msg.Error == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
gotErr = true
|
||||
}
|
||||
if !gotErr {
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||
}
|
||||
|
||||
data := make(chan []byte)
|
||||
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
body := recorder.Body.String()
|
||||
if !strings.Contains(body, `"type":"error"`) {
|
||||
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||
}
|
||||
if strings.Contains(body, `"error":{`) {
|
||||
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||
}
|
||||
}
|
||||
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type openAIResponsesStreamErrorChunk struct {
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
SequenceNumber int `json:"sequence_number"`
|
||||
}
|
||||
|
||||
func openAIResponsesStreamErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
return "invalid_api_key"
|
||||
case http.StatusForbidden:
|
||||
return "insufficient_quota"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "model_not_found"
|
||||
case http.StatusRequestTimeout:
|
||||
return "request_timeout"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "internal_server_error"
|
||||
}
|
||||
if status >= http.StatusBadRequest {
|
||||
return "invalid_request_error"
|
||||
}
|
||||
return "unknown_error"
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||
//
|
||||
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||
// of chunks that requires a top-level `type` field.
|
||||
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||
if status <= 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if sequenceNumber < 0 {
|
||||
sequenceNumber = 0
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(errText)
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
|
||||
code := openAIResponsesStreamErrorCode(status)
|
||||
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
if v, ok := payload["code"]; ok && v != nil {
|
||||
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||
code = strings.TrimSpace(c)
|
||||
} else {
|
||||
code = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||
sequenceNumber = int(v)
|
||||
}
|
||||
}
|
||||
if e, ok := payload["error"].(map[string]any); ok {
|
||||
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
if v, ok := e["code"]; ok && v != nil {
|
||||
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||
code = strings.TrimSpace(c)
|
||||
} else {
|
||||
code = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(code) == "" {
|
||||
code = "unknown_error"
|
||||
}
|
||||
|
||||
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||
Type: "error",
|
||||
Code: code,
|
||||
Message: message,
|
||||
SequenceNumber: sequenceNumber,
|
||||
})
|
||||
if err == nil {
|
||||
return data
|
||||
}
|
||||
|
||||
// Extremely defensive fallback.
|
||||
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||
Type: "error",
|
||||
Code: "internal_server_error",
|
||||
Message: message,
|
||||
SequenceNumber: sequenceNumber,
|
||||
})
|
||||
if len(data) > 0 {
|
||||
return data
|
||||
}
|
||||
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||
}
|
||||
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
|
||||
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if payload["type"] != "error" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||
}
|
||||
if payload["code"] != "internal_server_error" {
|
||||
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||
}
|
||||
if payload["message"] != "unexpected EOF" {
|
||||
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
|
||||
}
|
||||
if payload["sequence_number"] != float64(0) {
|
||||
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
|
||||
chunk := BuildOpenAIResponsesStreamErrorChunk(
|
||||
http.StatusInternalServerError,
|
||||
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
|
||||
0,
|
||||
)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if payload["type"] != "error" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||
}
|
||||
if payload["code"] != "internal_server_error" {
|
||||
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||
}
|
||||
if payload["message"] != "oops" {
|
||||
t.Fatalf("message = %v, want %q", payload["message"], "oops")
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
if shouldUseCodexDeviceFlow(opts) {
|
||||
return a.loginWithDeviceFlow(ctx, cfg, opts)
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
@@ -186,39 +188,5 @@ waitForCallback:
|
||||
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||
}
|
||||
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
if tokenStorage == nil || tokenStorage.Email == "" {
|
||||
return nil, fmt.Errorf("codex token storage missing account information")
|
||||
}
|
||||
|
||||
planType := ""
|
||||
hashAccountID := ""
|
||||
if tokenStorage.IDToken != "" {
|
||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
}
|
||||
}
|
||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||
metadata := map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
}
|
||||
|
||||
fmt.Println("Codex authentication successful")
|
||||
if authBundle.APIKey != "" {
|
||||
fmt.Println("Codex API key obtained and stored")
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
return a.buildAuthRecord(authSvc, authBundle)
|
||||
}
|
||||
|
||||
291
sdk/auth/codex_device.go
Normal file
291
sdk/auth/codex_device.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
|
||||
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
|
||||
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
|
||||
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
|
||||
codexDeviceTimeout = 15 * time.Minute
|
||||
codexDeviceDefaultPollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type codexDeviceUserCodeRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
type codexDeviceUserCodeResponse struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
UserCode string `json:"user_code"`
|
||||
UserCodeAlt string `json:"usercode"`
|
||||
Interval json.RawMessage `json:"interval"`
|
||||
}
|
||||
|
||||
type codexDeviceTokenRequest struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
UserCode string `json:"user_code"`
|
||||
}
|
||||
|
||||
type codexDeviceTokenResponse struct {
|
||||
AuthorizationCode string `json:"authorization_code"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
|
||||
if opts == nil || opts.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||
|
||||
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
|
||||
if deviceCode == "" {
|
||||
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
|
||||
}
|
||||
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
|
||||
if deviceCode == "" || deviceAuthID == "" {
|
||||
return nil, fmt.Errorf("codex device flow did not return required fields")
|
||||
}
|
||||
|
||||
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
|
||||
|
||||
fmt.Println("Starting Codex device authentication...")
|
||||
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
|
||||
fmt.Printf("Codex device code: %s\n", deviceCode)
|
||||
|
||||
if !opts.NoBrowser {
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the device URL manually")
|
||||
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
|
||||
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
|
||||
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
|
||||
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
|
||||
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
|
||||
return nil, fmt.Errorf("codex device flow token response missing required fields")
|
||||
}
|
||||
|
||||
authSvc := codex.NewCodexAuth(cfg)
|
||||
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
|
||||
ctx,
|
||||
authCode,
|
||||
codexDeviceTokenExchangeRedirectURI,
|
||||
&codex.PKCECodes{
|
||||
CodeVerifier: codeVerifier,
|
||||
CodeChallenge: codeChallenge,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||
}
|
||||
|
||||
return a.buildAuthRecord(authSvc, authBundle)
|
||||
}
|
||||
|
||||
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
|
||||
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create codex device request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request codex device code: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
|
||||
}
|
||||
|
||||
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
|
||||
trimmed := strings.TrimSpace(string(respBody))
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
|
||||
}
|
||||
if trimmed == "" {
|
||||
trimmed = "empty response body"
|
||||
}
|
||||
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
|
||||
}
|
||||
|
||||
var parsed codexDeviceUserCodeResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
|
||||
}
|
||||
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
|
||||
deadline := time.Now().Add(codexDeviceTimeout)
|
||||
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(codexDeviceTokenRequest{
|
||||
DeviceAuthID: deviceAuthID,
|
||||
UserCode: userCode,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
|
||||
}
|
||||
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
|
||||
}
|
||||
|
||||
switch {
|
||||
case codexDeviceIsSuccessStatus(resp.StatusCode):
|
||||
var parsed codexDeviceTokenResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
|
||||
}
|
||||
return &parsed, nil
|
||||
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
continue
|
||||
}
|
||||
default:
|
||||
trimmed := strings.TrimSpace(string(respBody))
|
||||
if trimmed == "" {
|
||||
trimmed = "empty response body"
|
||||
}
|
||||
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
|
||||
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
|
||||
if len(raw) == 0 {
|
||||
return defaultInterval
|
||||
}
|
||||
|
||||
var asString string
|
||||
if err := json.Unmarshal(raw, &asString); err == nil {
|
||||
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
var asInt int
|
||||
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
|
||||
return time.Duration(asInt) * time.Second
|
||||
}
|
||||
|
||||
return defaultInterval
|
||||
}
|
||||
|
||||
func codexDeviceIsSuccessStatus(code int) bool {
|
||||
return code >= 200 && code < 300
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
if tokenStorage == nil || tokenStorage.Email == "" {
|
||||
return nil, fmt.Errorf("codex token storage missing account information")
|
||||
}
|
||||
|
||||
planType := ""
|
||||
hashAccountID := ""
|
||||
if tokenStorage.IDToken != "" {
|
||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||
metadata := map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
}
|
||||
|
||||
fmt.Println("Codex authentication successful")
|
||||
if authBundle.APIKey != "" {
|
||||
fmt.Println("Codex API key obtained and stored")
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
||||
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
||||
}
|
||||
|
||||
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
|
||||
type metadataSetter interface {
|
||||
SetMetadata(map[string]any)
|
||||
}
|
||||
|
||||
switch {
|
||||
case auth.Storage != nil:
|
||||
if setter, ok := auth.Storage.(metadataSetter); ok {
|
||||
setter.SetMetadata(auth.Metadata)
|
||||
}
|
||||
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ type RefreshEvaluator interface {
|
||||
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshMaxConcurrency = 16
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
@@ -155,7 +156,8 @@ type Manager struct {
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
// Auto refresh state
|
||||
refreshCancel context.CancelFunc
|
||||
refreshCancel context.CancelFunc
|
||||
refreshSemaphore chan struct{}
|
||||
}
|
||||
|
||||
// NewManager constructs a manager with optional custom selector and hook.
|
||||
@@ -173,6 +175,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
hook: hook,
|
||||
auths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
||||
}
|
||||
// atomic.Value requires non-nil initial value.
|
||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||
@@ -1828,9 +1831,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||
// every few seconds and triggers refresh operations when required.
|
||||
// Only one loop is kept alive; starting a new one cancels the previous run.
|
||||
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
||||
if interval <= 0 || interval > refreshCheckInterval {
|
||||
interval = refreshCheckInterval
|
||||
} else {
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
if m.refreshCancel != nil {
|
||||
@@ -1880,11 +1881,25 @@ func (m *Manager) checkRefreshes(ctx context.Context) {
|
||||
if !m.markRefreshPending(a.ID, now) {
|
||||
continue
|
||||
}
|
||||
go m.refreshAuth(ctx, a.ID)
|
||||
go m.refreshAuthWithLimit(ctx, a.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
|
||||
if m.refreshSemaphore == nil {
|
||||
m.refreshAuth(ctx, id)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.refreshSemaphore <- struct{}{}:
|
||||
defer func() { <-m.refreshSemaphore }()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
m.refreshAuth(ctx, id)
|
||||
}
|
||||
|
||||
func (m *Manager) snapshotAuths() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
}
|
||||
|
||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
|
||||
// a two-level round-robin is used: first cycling across credential groups (parent
|
||||
// accounts), then cycling within each group's project auths.
|
||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
||||
if limit <= 0 {
|
||||
limit = 4096
|
||||
}
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
index := s.cursors[key]
|
||||
|
||||
// Check if any available auth has gemini_virtual_parent attribute,
|
||||
// indicating gemini-cli virtual auths that should use credential-level polling.
|
||||
groups, parentOrder := groupByVirtualParent(available)
|
||||
if len(parentOrder) > 1 {
|
||||
// Two-level round-robin: first select a credential group, then pick within it.
|
||||
groupKey := key + "::group"
|
||||
s.ensureCursorKey(groupKey, limit)
|
||||
if _, exists := s.cursors[groupKey]; !exists {
|
||||
// Seed with a random initial offset so the starting credential is randomized.
|
||||
s.cursors[groupKey] = rand.IntN(len(parentOrder))
|
||||
}
|
||||
groupIndex := s.cursors[groupKey]
|
||||
if groupIndex >= 2_147_483_640 {
|
||||
groupIndex = 0
|
||||
}
|
||||
s.cursors[groupKey] = groupIndex + 1
|
||||
|
||||
selectedParent := parentOrder[groupIndex%len(parentOrder)]
|
||||
group := groups[selectedParent]
|
||||
|
||||
// Second level: round-robin within the selected credential group.
|
||||
innerKey := key + "::cred:" + selectedParent
|
||||
s.ensureCursorKey(innerKey, limit)
|
||||
innerIndex := s.cursors[innerKey]
|
||||
if innerIndex >= 2_147_483_640 {
|
||||
innerIndex = 0
|
||||
}
|
||||
s.cursors[innerKey] = innerIndex + 1
|
||||
s.mu.Unlock()
|
||||
return group[innerIndex%len(group)], nil
|
||||
}
|
||||
|
||||
// Flat round-robin for non-grouped auths (original behavior).
|
||||
s.ensureCursorKey(key, limit)
|
||||
index := s.cursors[key]
|
||||
if index >= 2_147_483_640 {
|
||||
index = 0
|
||||
}
|
||||
|
||||
s.cursors[key] = index + 1
|
||||
s.mu.Unlock()
|
||||
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
|
||||
return available[index%len(available)], nil
|
||||
}
|
||||
|
||||
// ensureCursorKey ensures the cursor map has capacity for the given key.
|
||||
// Must be called with s.mu held.
|
||||
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
}
|
||||
|
||||
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
|
||||
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
|
||||
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
|
||||
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
|
||||
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
|
||||
if len(auths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
groups := make(map[string][]*Auth)
|
||||
for _, a := range auths {
|
||||
parent := ""
|
||||
if a.Attributes != nil {
|
||||
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if parent == "" {
|
||||
// Non-virtual auth present; fall back to flat round-robin.
|
||||
return nil, nil
|
||||
}
|
||||
groups[parent] = append(groups[parent], a)
|
||||
}
|
||||
// Collect parent IDs in sorted order for stable cursor indexing.
|
||||
parentOrder := make([]string, 0, len(groups))
|
||||
for p := range groups {
|
||||
parentOrder = append(parentOrder, p)
|
||||
}
|
||||
sort.Strings(parentOrder)
|
||||
return groups, parentOrder
|
||||
}
|
||||
|
||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = opts
|
||||
|
||||
@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
||||
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// Simulate two gemini-cli credentials, each with multiple projects:
|
||||
// Credential A (parent = "cred-a.json") has 3 projects
|
||||
// Credential B (parent = "cred-b.json") has 2 projects
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||
}
|
||||
|
||||
// Two-level round-robin: consecutive picks must alternate between credentials.
|
||||
// Credential group order is randomized, but within each call the group cursor
|
||||
// advances by 1, so consecutive picks should cycle through different parents.
|
||||
picks := make([]string, 6)
|
||||
parents := make([]string, 6)
|
||||
for i := 0; i < 6; i++ {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
picks[i] = got.ID
|
||||
parents[i] = got.Attributes["gemini_virtual_parent"]
|
||||
}
|
||||
|
||||
// Verify property: consecutive picks must alternate between credential groups.
|
||||
for i := 1; i < len(parents); i++ {
|
||||
if parents[i] == parents[i-1] {
|
||||
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
|
||||
i-1, i, parents[i], picks[i-1], picks[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify property: each credential's projects are picked in sequence (round-robin within group).
|
||||
credPicks := map[string][]string{}
|
||||
for i, id := range picks {
|
||||
credPicks[parents[i]] = append(credPicks[parents[i]], id)
|
||||
}
|
||||
for parent, ids := range credPicks {
|
||||
for i := 1; i < len(ids); i++ {
|
||||
if ids[i] == ids[i-1] {
|
||||
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// All auths from the same parent - should fall back to flat round-robin
|
||||
// because there's only one credential group (no benefit from two-level).
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
}
|
||||
|
||||
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
|
||||
// Sorted by ID: proj-a1, proj-a2, proj-a3
|
||||
want := []string{
|
||||
"cred-a.json::proj-a1",
|
||||
"cred-a.json::proj-a2",
|
||||
"cred-a.json::proj-a3",
|
||||
"cred-a.json::proj-a1",
|
||||
}
|
||||
|
||||
for i, expectedID := range want {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
|
||||
// alongside virtual ones). Should fall back to flat round-robin.
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-regular.json"}, // no gemini_virtual_parent
|
||||
}
|
||||
|
||||
// groupByVirtualParent returns nil when any auth lacks the attribute,
|
||||
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
|
||||
want := []string{
|
||||
"cred-a.json::proj-a1",
|
||||
"cred-regular.json",
|
||||
"cred-a.json::proj-a1",
|
||||
}
|
||||
|
||||
for i, expectedID := range want {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -12,6 +15,33 @@ import (
|
||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||
)
|
||||
|
||||
// PostAuthHook defines a function that is called after an Auth record is created
|
||||
// but before it is persisted to storage. This allows for modification of the
|
||||
// Auth record (e.g., injecting metadata) based on external context.
|
||||
type PostAuthHook func(context.Context, *Auth) error
|
||||
|
||||
// RequestInfo holds information extracted from the HTTP request.
|
||||
// It is injected into the context passed to PostAuthHook.
|
||||
type RequestInfo struct {
|
||||
Query url.Values
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
type requestInfoKey struct{}
|
||||
|
||||
// WithRequestInfo returns a new context with the given RequestInfo attached.
|
||||
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
|
||||
return context.WithValue(ctx, requestInfoKey{}, info)
|
||||
}
|
||||
|
||||
// GetRequestInfo retrieves the RequestInfo from the context, if present.
|
||||
func GetRequestInfo(ctx context.Context) *RequestInfo {
|
||||
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
|
||||
return val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
||||
type Auth struct {
|
||||
// ID uniquely identifies the auth record across restarts.
|
||||
|
||||
@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after an Auth record is created
|
||||
// but before it is persisted to storage.
|
||||
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
|
||||
if hook == nil {
|
||||
return b
|
||||
}
|
||||
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
|
||||
return b
|
||||
}
|
||||
|
||||
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
||||
func (b *Builder) Build() (*Service, error) {
|
||||
if b.cfg == nil {
|
||||
|
||||
@@ -925,6 +925,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
if provider == "antigravity" {
|
||||
s.backfillAntigravityModels(a, models)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1069,6 +1072,56 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
|
||||
return cfg.OAuthExcludedModels[providerKey]
|
||||
}
|
||||
|
||||
func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
|
||||
if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sourceID := ""
|
||||
if source != nil {
|
||||
sourceID = strings.TrimSpace(source.ID)
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
for _, candidate := range s.coreManager.List() {
|
||||
if candidate == nil || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
candidateID := strings.TrimSpace(candidate.ID)
|
||||
if candidateID == "" || candidateID == sourceID {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
|
||||
continue
|
||||
}
|
||||
if len(reg.GetModelsForClient(candidateID)) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
|
||||
if authKind == "" {
|
||||
if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
excluded := s.oauthExcludedModels("antigravity", authKind)
|
||||
if candidate.Attributes != nil {
|
||||
if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
|
||||
excluded = strings.Split(val, ",")
|
||||
}
|
||||
}
|
||||
|
||||
models := applyExcludedModels(primaryModels, excluded)
|
||||
models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
|
||||
if len(models) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
|
||||
}
|
||||
}
|
||||
|
||||
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
if len(models) == 0 || len(excluded) == 0 {
|
||||
return models
|
||||
|
||||
135
sdk/cliproxy/service_antigravity_backfill_test.go
Normal file
135
sdk/cliproxy/service_antigravity_backfill_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
|
||||
source := &coreauth.Auth{
|
||||
ID: "ag-backfill-source",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
target := &coreauth.Auth{
|
||||
ID: "ag-backfill-target",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||
t.Fatalf("register source auth: %v", err)
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||
t.Fatalf("register target auth: %v", err)
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: manager,
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
})
|
||||
|
||||
primary := []*ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||
|
||||
service.backfillAntigravityModels(source, primary)
|
||||
|
||||
got := reg.GetModelsForClient(target.ID)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
|
||||
}
|
||||
|
||||
ids := make(map[string]struct{}, len(got))
|
||||
for _, model := range got {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
|
||||
}
|
||||
if _, ok := ids["claude-sonnet-4-5"]; !ok {
|
||||
t.Fatal("expected backfilled model claude-sonnet-4-5")
|
||||
}
|
||||
if _, ok := ids["gemini-2.5-pro"]; !ok {
|
||||
t.Fatal("expected backfilled model gemini-2.5-pro")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
|
||||
source := &coreauth.Auth{
|
||||
ID: "ag-backfill-source-excluded",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
target := &coreauth.Auth{
|
||||
ID: "ag-backfill-target-excluded",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"excluded_models": "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||
t.Fatalf("register source auth: %v", err)
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||
t.Fatalf("register target auth: %v", err)
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: manager,
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
})
|
||||
|
||||
primary := []*ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||
|
||||
service.backfillAntigravityModels(source, primary)
|
||||
|
||||
got := reg.GetModelsForClient(target.ID)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 model after exclusion, got %d", len(got))
|
||||
}
|
||||
if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
|
||||
t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user