mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-23 20:02:40 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3100085b0 | ||
|
|
f481d25133 | ||
|
|
8c6c90da74 | ||
|
|
24bcfd9c03 | ||
|
|
816fb4c5da | ||
|
|
c1bb77c7c9 | ||
|
|
6bcac3a55a | ||
|
|
fc346f4537 | ||
|
|
43e531a3b6 | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
514ae341c8 | ||
|
|
8ce07f38dd | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
450d1227bd | ||
|
|
4e26182d14 | ||
|
|
3b421c8181 | ||
|
|
afc8a0f9be | ||
|
|
d693d7993b |
@@ -1929,8 +1929,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
// Initialize Copilot auth service
|
// Initialize Copilot auth service
|
||||||
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
|
|
||||||
// Assuming copilot package is imported as "copilot"
|
|
||||||
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
||||||
|
|
||||||
// Initiate device flow
|
// Initiate device flow
|
||||||
@@ -1944,7 +1942,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
authURL := deviceCode.VerificationURI
|
authURL := deviceCode.VerificationURI
|
||||||
userCode := deviceCode.UserCode
|
userCode := deviceCode.UserCode
|
||||||
|
|
||||||
RegisterOAuthSession(state, "github")
|
RegisterOAuthSession(state, "github-copilot")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
||||||
@@ -1956,9 +1954,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
if errUser != nil {
|
if errUser != nil {
|
||||||
log.Warnf("Failed to fetch user info: %v", errUser)
|
log.Warnf("Failed to fetch user info: %v", errUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := userInfo.Login
|
||||||
|
if username == "" {
|
||||||
username = "github-user"
|
username = "github-user"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1967,18 +1969,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
TokenType: tokenData.TokenType,
|
TokenType: tokenData.TokenType,
|
||||||
Scope: tokenData.Scope,
|
Scope: tokenData.Scope,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
Name: userInfo.Name,
|
||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := fmt.Sprintf("github-%s.json", username)
|
fileName := fmt.Sprintf("github-copilot-%s.json", username)
|
||||||
|
label := userInfo.Email
|
||||||
|
if label == "" {
|
||||||
|
label = username
|
||||||
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: "github",
|
Provider: "github-copilot",
|
||||||
|
Label: label,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": username,
|
"email": userInfo.Email,
|
||||||
"username": username,
|
"username": username,
|
||||||
|
"name": userInfo.Name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1992,7 +2002,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
||||||
CompleteOAuthSession(state)
|
CompleteOAuthSession(state)
|
||||||
CompleteOAuthSessionsByProvider("github")
|
CompleteOAuthSessionsByProvider("github-copilot")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
|
|||||||
@@ -276,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return tokenData, nil
|
return tokenData, nil
|
||||||
}
|
}
|
||||||
|
if isNonRetryableRefreshErr(err) {
|
||||||
|
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
lastErr = err
|
lastErr = err
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||||
@@ -284,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isNonRetryableRefreshErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(raw, "refresh_token_reused")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||||
// This is typically called after a successful token refresh to persist the new credentials.
|
// This is typically called after a successful token refresh to persist the new credentials.
|
||||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||||
|
|||||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
auth := &CodexAuth{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for non-retryable refresh failure")
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||||
|
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the GitHub username
|
// Fetch the GitHub username
|
||||||
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||||
username = "unknown"
|
}
|
||||||
|
|
||||||
|
username := userInfo.Login
|
||||||
|
if username == "" {
|
||||||
|
username = "github-user"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CopilotAuthBundle{
|
return &CopilotAuthBundle{
|
||||||
TokenData: tokenData,
|
TokenData: tokenData,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
Name: userInfo.Name,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
|
|||||||
return false, "", nil
|
return false, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", err
|
return false, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, username, nil
|
return true, userInfo.Login, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||||
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
|
|||||||
TokenType: bundle.TokenData.TokenType,
|
TokenType: bundle.TokenData.TokenType,
|
||||||
Scope: bundle.TokenData.Scope,
|
Scope: bundle.TokenData.Scope,
|
||||||
Username: bundle.Username,
|
Username: bundle.Username,
|
||||||
|
Email: bundle.Email,
|
||||||
|
Name: bundle.Name,
|
||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
|||||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("client_id", copilotClientID)
|
data.Set("client_id", copilotClientID)
|
||||||
data.Set("scope", "user:email")
|
data.Set("scope", "read:user user:email")
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
// GitHubUserInfo holds GitHub user profile information.
|
||||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
type GitHubUserInfo struct {
|
||||||
|
// Login is the GitHub username.
|
||||||
|
Login string
|
||||||
|
// Email is the primary email address (may be empty if not public).
|
||||||
|
Email string
|
||||||
|
// Name is the display name.
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||||
|
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
|||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
|||||||
|
|
||||||
if !isHTTPSuccess(resp.StatusCode) {
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var userInfo struct {
|
var raw struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userInfo.Login == "" {
|
if raw.Login == "" {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo.Login, nil
|
return GitHubUserInfo{
|
||||||
|
Login: raw.Login,
|
||||||
|
Email: raw.Email,
|
||||||
|
Name: raw.Name,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// roundTripFunc lets us inject a custom transport for testing.
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||||
|
|
||||||
|
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||||
|
// regardless of the original URL host.
|
||||||
|
func newTestClient(srv *httptest.Server) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req2 := req.Clone(req.Context())
|
||||||
|
req2.URL.Scheme = "http"
|
||||||
|
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||||
|
return srv.Client().Transport.RoundTrip(req2)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||||
|
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"login": "octocat",
|
||||||
|
"email": "octocat@github.com",
|
||||||
|
"name": "The Octocat",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if info.Login != "octocat" {
|
||||||
|
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||||
|
}
|
||||||
|
if info.Email != "octocat@github.com" {
|
||||||
|
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||||
|
}
|
||||||
|
if info.Name != "The Octocat" {
|
||||||
|
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||||
|
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// GitHub returns null for private emails.
|
||||||
|
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if info.Login != "privateuser" {
|
||||||
|
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||||
|
}
|
||||||
|
if info.Email != "" {
|
||||||
|
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||||
|
}
|
||||||
|
if info.Name != "Private User" {
|
||||||
|
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||||
|
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||||
|
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty token, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||||
|
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty login, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||||
|
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401 response, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||||
|
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||||
|
ts := &CopilotTokenStorage{
|
||||||
|
AccessToken: "ghu_abc",
|
||||||
|
TokenType: "bearer",
|
||||||
|
Scope: "read:user user:email",
|
||||||
|
Username: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err = json.Unmarshal(data, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||||
|
if _, ok := out[key]; !ok {
|
||||||
|
t.Errorf("expected key %q in JSON output, not found", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out["email"] != "octocat@github.com" {
|
||||||
|
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||||
|
}
|
||||||
|
if out["name"] != "The Octocat" {
|
||||||
|
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||||
|
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||||
|
ts := &CopilotTokenStorage{
|
||||||
|
AccessToken: "ghu_abc",
|
||||||
|
Username: "octocat",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err = json.Unmarshal(data, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := out["email"]; ok {
|
||||||
|
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||||
|
}
|
||||||
|
if _, ok := out["name"]; ok {
|
||||||
|
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||||
|
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||||
|
bundle := &CopilotAuthBundle{
|
||||||
|
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||||
|
Username: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
}
|
||||||
|
if bundle.Email != "octocat@github.com" {
|
||||||
|
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||||
|
}
|
||||||
|
if bundle.Name != "The Octocat" {
|
||||||
|
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||||
|
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||||
|
info := GitHubUserInfo{
|
||||||
|
Login: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
}
|
||||||
|
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||||
|
t.Error("GitHubUserInfo fields should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
|
|||||||
ExpiresAt string `json:"expires_at,omitempty"`
|
ExpiresAt string `json:"expires_at,omitempty"`
|
||||||
// Username is the GitHub username associated with this token.
|
// Username is the GitHub username associated with this token.
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
// Email is the GitHub email address associated with this token.
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
// Name is the GitHub display name associated with this token.
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
|
|||||||
TokenData *CopilotTokenData
|
TokenData *CopilotTokenData
|
||||||
// Username is the GitHub username.
|
// Username is the GitHub username.
|
||||||
Username string
|
Username string
|
||||||
|
// Email is the GitHub email address.
|
||||||
|
Email string
|
||||||
|
// Name is the GitHub display name.
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceCodeResponse represents GitHub's device code response.
|
// DeviceCodeResponse represents GitHub's device code response.
|
||||||
|
|||||||
@@ -916,19 +916,12 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
Created int64
|
Created int64
|
||||||
Thinking *ThinkingSupport
|
Thinking *ThinkingSupport
|
||||||
}{
|
}{
|
||||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
|
||||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
|
||||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
|
||||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||||
@@ -937,11 +930,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
|
||||||
}
|
}
|
||||||
models := make([]*ModelInfo, 0, len(entries))
|
models := make([]*ModelInfo, 0, len(entries))
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
|
|||||||
@@ -54,8 +54,78 @@ const (
|
|||||||
var (
|
var (
|
||||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
randSourceMutex sync.Mutex
|
randSourceMutex sync.Mutex
|
||||||
|
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||||
|
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||||
|
antigravityPrimaryModelsCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
models []*registry.ModelInfo
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*registry.ModelInfo, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, cloneAntigravityModelInfo(model))
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||||
|
if model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *model
|
||||||
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
|
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
|
}
|
||||||
|
if len(model.SupportedParameters) > 0 {
|
||||||
|
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||||
|
}
|
||||||
|
if model.Thinking != nil {
|
||||||
|
thinkingClone := *model.Thinking
|
||||||
|
if len(model.Thinking.Levels) > 0 {
|
||||||
|
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
clone.Thinking = &thinkingClone
|
||||||
|
}
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||||
|
cloned := cloneAntigravityModels(models)
|
||||||
|
if len(cloned) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
antigravityPrimaryModelsCache.mu.Lock()
|
||||||
|
antigravityPrimaryModelsCache.models = cloned
|
||||||
|
antigravityPrimaryModelsCache.mu.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||||
|
antigravityPrimaryModelsCache.mu.RLock()
|
||||||
|
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||||
|
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||||
|
models := loadAntigravityPrimaryModels()
|
||||||
|
if len(models) > 0 {
|
||||||
|
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
type AntigravityExecutor struct {
|
type AntigravityExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -1006,14 +1076,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||||
exec := &AntigravityExecutor{cfg: cfg}
|
exec := &AntigravityExecutor{cfg: cfg}
|
||||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil || token == "" {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
}
|
||||||
}
|
|
||||||
if token == "" {
|
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if updatedAuth != nil {
|
if updatedAuth != nil {
|
||||||
auth = updatedAuth
|
auth = updatedAuth
|
||||||
}
|
}
|
||||||
@@ -1025,8 +1090,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
modelsURL := baseURL + antigravityModelsPath
|
modelsURL := baseURL + antigravityModelsPath
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||||
@@ -1038,15 +1102,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -1058,22 +1120,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
|
if idx+1 < len(baseURLs) {
|
||||||
return nil
|
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
result := gjson.GetBytes(bodyBytes, "models")
|
result := gjson.GetBytes(bodyBytes, "models")
|
||||||
if !result.Exists() {
|
if !result.Exists() {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
|
if idx+1 < len(baseURLs) {
|
||||||
return nil
|
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
@@ -1118,9 +1185,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
}
|
}
|
||||||
models = append(models, modelInfo)
|
models = append(models, modelInfo)
|
||||||
}
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
|
}
|
||||||
|
storeAntigravityPrimaryModels(models)
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
return nil
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||||
|
antigravityPrimaryModelsCache.mu.Lock()
|
||||||
|
antigravityPrimaryModelsCache.models = nil
|
||||||
|
antigravityPrimaryModelsCache.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||||
|
resetAntigravityPrimaryModelsCacheForTest()
|
||||||
|
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||||
|
|
||||||
|
seed := []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||||
|
t.Fatal("expected non-empty model list to update primary cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||||
|
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||||
|
}
|
||||||
|
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||||
|
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := loadAntigravityPrimaryModels()
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||||
|
}
|
||||||
|
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||||
|
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||||
|
resetAntigravityPrimaryModelsCacheForTest()
|
||||||
|
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||||
|
|
||||||
|
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||||
|
ID: "gpt-5",
|
||||||
|
DisplayName: "GPT-5",
|
||||||
|
SupportedGenerationMethods: []string{"generateContent"},
|
||||||
|
SupportedParameters: []string{"temperature"},
|
||||||
|
Thinking: ®istry.ThinkingSupport{
|
||||||
|
Levels: []string{"high"},
|
||||||
|
},
|
||||||
|
}}); !updated {
|
||||||
|
t.Fatal("expected model cache update")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := loadAntigravityPrimaryModels()
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("expected one cached model, got %d", len(got))
|
||||||
|
}
|
||||||
|
got[0].ID = "mutated-id"
|
||||||
|
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||||
|
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||||
|
}
|
||||||
|
if len(got[0].SupportedParameters) > 0 {
|
||||||
|
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||||
|
}
|
||||||
|
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||||
|
got[0].Thinking.Levels[0] = "mutated-level"
|
||||||
|
}
|
||||||
|
|
||||||
|
again := loadAntigravityPrimaryModels()
|
||||||
|
if len(again) != 1 {
|
||||||
|
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||||
|
}
|
||||||
|
if again[0].ID != "gpt-5" {
|
||||||
|
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||||
|
}
|
||||||
|
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||||
|
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||||
|
}
|
||||||
|
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||||
|
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||||
|
}
|
||||||
|
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||||
|
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
@@ -22,9 +23,151 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||||
|
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||||
|
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||||
|
var qwenBeijingLoc = func() *time.Location {
|
||||||
|
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
if err != nil || loc == nil {
|
||||||
|
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||||
|
return time.FixedZone("CST", 8*3600)
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
}()
|
||||||
|
|
||||||
|
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||||
|
var qwenQuotaCodes = map[string]struct{}{
|
||||||
|
"insufficient_quota": {},
|
||||||
|
"quota_exceeded": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||||
|
// Qwen has a limit of 60 requests per minute per account.
|
||||||
|
var qwenRateLimiter = struct {
|
||||||
|
sync.Mutex
|
||||||
|
requests map[string][]time.Time // authID -> request timestamps
|
||||||
|
}{
|
||||||
|
requests: make(map[string][]time.Time),
|
||||||
|
}
|
||||||
|
|
||||||
|
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||||
|
// Keeps a small prefix/suffix to allow correlation across events.
|
||||||
|
func redactAuthID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(id) <= 8 {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return id[:4] + "..." + id[len(id)-4:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||||
|
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||||
|
func checkQwenRateLimit(authID string) error {
|
||||||
|
if authID == "" {
|
||||||
|
// Empty authID should not bypass rate limiting in production
|
||||||
|
// Use debug level to avoid log spam for certain auth flows
|
||||||
|
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := now.Add(-qwenRateLimitWindow)
|
||||||
|
|
||||||
|
qwenRateLimiter.Lock()
|
||||||
|
defer qwenRateLimiter.Unlock()
|
||||||
|
|
||||||
|
// Get and filter timestamps within the window
|
||||||
|
timestamps := qwenRateLimiter.requests[authID]
|
||||||
|
var validTimestamps []time.Time
|
||||||
|
for _, ts := range timestamps {
|
||||||
|
if ts.After(windowStart) {
|
||||||
|
validTimestamps = append(validTimestamps, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always prune expired entries to prevent memory leak
|
||||||
|
// Delete empty entries, otherwise update with pruned slice
|
||||||
|
if len(validTimestamps) == 0 {
|
||||||
|
delete(qwenRateLimiter.requests, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if rate limit exceeded
|
||||||
|
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||||
|
// Calculate when the oldest request will expire
|
||||||
|
oldestInWindow := validTimestamps[0]
|
||||||
|
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||||
|
if retryAfter < time.Second {
|
||||||
|
retryAfter = time.Second
|
||||||
|
}
|
||||||
|
retryAfterSec := int(retryAfter.Seconds())
|
||||||
|
return statusErr{
|
||||||
|
code: http.StatusTooManyRequests,
|
||||||
|
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||||
|
retryAfter: &retryAfter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record this request and update the map with pruned timestamps
|
||||||
|
validTimestamps = append(validTimestamps, now)
|
||||||
|
qwenRateLimiter.requests[authID] = validTimestamps
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||||
|
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||||
|
func isQwenQuotaError(body []byte) bool {
|
||||||
|
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||||
|
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||||
|
|
||||||
|
// Primary check: exact match on error.code or error.type (most reliable)
|
||||||
|
if _, ok := qwenQuotaCodes[code]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check message only if code/type don't match (less reliable)
|
||||||
|
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||||
|
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||||
|
strings.Contains(msg, "free allocated quota exceeded") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||||
|
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||||
|
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||||
|
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||||
|
errCode = httpCode
|
||||||
|
// Only check quota errors for expected status codes to avoid false positives
|
||||||
|
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||||
|
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||||
|
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||||
|
cooldown := timeUntilNextDay()
|
||||||
|
retryAfter = &cooldown
|
||||||
|
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||||
|
}
|
||||||
|
return errCode, retryAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||||
|
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||||
|
func timeUntilNextDay() time.Duration {
|
||||||
|
now := time.Now()
|
||||||
|
nowLocal := now.In(qwenBeijingLoc)
|
||||||
|
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||||
|
return tomorrow.Sub(now)
|
||||||
|
}
|
||||||
|
|
||||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||||
type QwenExecutor struct {
|
type QwenExecutor struct {
|
||||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, false)
|
applyQwenHeaders(httpReq, token, false)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, true)
|
applyQwenHeaders(httpReq, token, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
|||||||
@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||||
} else if functionResponseResult.IsArray() {
|
} else if functionResponseResult.IsArray() {
|
||||||
frResults := functionResponseResult.Array()
|
frResults := functionResponseResult.Array()
|
||||||
if len(frResults) == 1 {
|
nonImageCount := 0
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
lastNonImageRaw := ""
|
||||||
|
filteredJSON := "[]"
|
||||||
|
imagePartsJSON := "[]"
|
||||||
|
for _, fr := range frResults {
|
||||||
|
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||||
|
inlineDataJSON := `{}`
|
||||||
|
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
|
}
|
||||||
|
if data := fr.Get("source.data").String(); data != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePartJSON := `{}`
|
||||||
|
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
|
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nonImageCount++
|
||||||
|
lastNonImageRaw = fr.Raw
|
||||||
|
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nonImageCount == 1 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||||
|
} else if nonImageCount > 1 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||||
} else {
|
} else {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Place image data inside functionResponse.parts as inlineData
|
||||||
|
// instead of as sibling parts in the outer content, to avoid
|
||||||
|
// base64 data bloating the text context.
|
||||||
|
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if functionResponseResult.IsObject() {
|
} else if functionResponseResult.IsObject() {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||||
|
inlineDataJSON := `{}`
|
||||||
|
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
|
}
|
||||||
|
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePartJSON := `{}`
|
||||||
|
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
|
imagePartsJSON := "[]"
|
||||||
|
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
|
} else {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
|
}
|
||||||
} else if functionResponseResult.Raw != "" {
|
} else if functionResponseResult.Raw != "" {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
} else {
|
} else {
|
||||||
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if sourceResult.Get("type").String() == "base64" {
|
if sourceResult.Get("type").String() == "base64" {
|
||||||
inlineDataJSON := `{}`
|
inlineDataJSON := `{}`
|
||||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
}
|
}
|
||||||
if data := sourceResult.Get("data").String(); data != "" {
|
if data := sourceResult.Get("data").String(); data != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
|||||||
@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
|||||||
if !inlineData.Exists() {
|
if !inlineData.Exists() {
|
||||||
t.Error("inlineData should exist")
|
t.Error("inlineData should exist")
|
||||||
}
|
}
|
||||||
if inlineData.Get("mime_type").String() != "image/png" {
|
if inlineData.Get("mimeType").String() != "image/png" {
|
||||||
t.Error("mime_type mismatch")
|
t.Error("mimeType mismatch")
|
||||||
}
|
}
|
||||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||||
t.Error("data mismatch")
|
t.Error("data mismatch")
|
||||||
@@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
|
||||||
|
// tool_result with array content containing text + image should place
|
||||||
|
// image data inside functionResponse.parts as inlineData, not as a
|
||||||
|
// sibling part in the outer content (to avoid base64 context bloat).
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-123-456",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "File content here"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be inside functionResponse.parts, not as outer sibling part
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text content should be in response.result
|
||||||
|
resultText := funcResp.Get("response.result.text").String()
|
||||||
|
if resultText != "File content here" {
|
||||||
|
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be in functionResponse.parts[0].inlineData
|
||||||
|
inlineData := funcResp.Get("parts.0.inlineData")
|
||||||
|
if !inlineData.Exists() {
|
||||||
|
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||||
|
}
|
||||||
|
if inlineData.Get("mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||||
|
t.Error("data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should NOT be in outer parts (only functionResponse part should exist)
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||||
|
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||||
|
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
|
||||||
|
// tool_result with single image object as content should place
|
||||||
|
// image data inside functionResponse.parts, not as outer sibling part.
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-789-012",
|
||||||
|
"content": {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/jpeg",
|
||||||
|
"data": "/9j/4AAQSkZJRgABAQ=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// response.result should be empty (image only)
|
||||||
|
if funcResp.Get("response.result").String() != "" {
|
||||||
|
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be in functionResponse.parts[0].inlineData
|
||||||
|
inlineData := funcResp.Get("parts.0.inlineData")
|
||||||
|
if !inlineData.Exists() {
|
||||||
|
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||||
|
}
|
||||||
|
if inlineData.Get("mimeType").String() != "image/jpeg" {
|
||||||
|
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should NOT be in outer parts
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||||
|
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||||
|
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
|
||||||
|
// tool_result with array content: 2 text items + 2 images
|
||||||
|
// All images go into functionResponse.parts, texts into response.result array
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Multi-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "First text"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Second text"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple text items => response.result is an array
|
||||||
|
resultArr := funcResp.Get("response.result")
|
||||||
|
if !resultArr.IsArray() {
|
||||||
|
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
|
||||||
|
}
|
||||||
|
results := resultArr.Array()
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both images should be in functionResponse.parts
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||||
|
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
|
||||||
|
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
|
||||||
|
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 outer part (the functionResponse itself)
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(outerParts) != 1 {
|
||||||
|
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
|
||||||
|
// tool_result with only images (no text) — response.result should be empty string
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "ImgOnly-001",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No text => response.result should be empty string
|
||||||
|
if funcResp.Get("response.result").String() != "" {
|
||||||
|
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both images in functionResponse.parts
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Error("first image mimeType mismatch")
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
|
||||||
|
t.Error("second image mimeType mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 outer part
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(outerParts) != 1 {
|
||||||
|
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
|
||||||
|
// image with source.type != "base64" should be treated as non-image (falls through)
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NotB64-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "some output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "url", "url": "https://example.com/img.png"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-base64 image is treated as non-image, so it goes into the filtered results
|
||||||
|
// along with the text item. Since there are 2 non-image items, result is array.
|
||||||
|
resultArr := funcResp.Get("response.result")
|
||||||
|
if !resultArr.IsArray() {
|
||||||
|
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
|
||||||
|
}
|
||||||
|
results := resultArr.Array()
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// No functionResponse.parts (no base64 images collected)
|
||||||
|
if funcResp.Get("parts").Exists() {
|
||||||
|
t.Error("functionResponse.parts should NOT exist when no base64 images")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
|
||||||
|
// image with source.type=base64 but missing data field
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NoData-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The image is still classified as base64 image (type check passes),
|
||||||
|
// but data field is missing => inlineData has mimeType but no data
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Error("mimeType should still be set")
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").Exists() {
|
||||||
|
t.Error("data should not exist when source.data is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
|
||||||
|
// image with source.type=base64 but missing media_type field
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NoMime-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "data": "AAAA"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The image is still classified as base64 image,
|
||||||
|
// but media_type is missing => inlineData has data but no mimeType
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").Exists() {
|
||||||
|
t.Error("mimeType should not exist when media_type is missing")
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||||
|
t.Error("data should still be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||||
// When tools + thinking but no system instruction, should create one with hint
|
// When tools + thinking but no system instruction, should create one with hint
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
|
|||||||
@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||||
|
// When functionResponse contains a "parts" field with inlineData (from Claude
|
||||||
|
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
|
||||||
|
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
|
||||||
|
// so extra fields like "parts" survive the pipeline.
|
||||||
|
input := `{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionCall": {"name": "screenshot", "args": {}}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionResponse": {
|
||||||
|
"id": "tool-001",
|
||||||
|
"name": "screenshot",
|
||||||
|
"response": {"result": "Screenshot taken"},
|
||||||
|
"parts": [
|
||||||
|
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the function response content (role=function)
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The functionResponse should be preserved with its parts field
|
||||||
|
funcResp := funcContent.Get("parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the parts field with inlineData is preserved
|
||||||
|
inlineParts := funcResp.Get("parts").Array()
|
||||||
|
if len(inlineParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
|
||||||
|
}
|
||||||
|
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
|
||||||
|
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response.result is also preserved
|
||||||
|
if funcResp.Get("response.result").String() != "Screenshot taken" {
|
||||||
|
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
mime := pieces[0]
|
mime := pieces[0]
|
||||||
data := pieces[1][7:]
|
data := pieces[1][7:]
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
ext = sp[len(sp)-1]
|
ext = sp[len(sp)-1]
|
||||||
}
|
}
|
||||||
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
||||||
p++
|
p++
|
||||||
} else {
|
} else {
|
||||||
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
mime := pieces[0]
|
mime := pieces[0]
|
||||||
data := pieces[1][7:]
|
data := pieces[1][7:]
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||||
|
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||||
|
|
||||||
// Delete the user field as it is not supported by the Codex upstream.
|
// Delete the user field as it is not supported by the Codex upstream.
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
return rawJSON
|
return rawJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||||
|
// for Codex upstream compatibility.
|
||||||
|
//
|
||||||
|
// Codex /responses currently rejects context_management with:
|
||||||
|
// {"detail":"Unsupported parameter: context_management"}.
|
||||||
|
//
|
||||||
|
// Compatibility strategy:
|
||||||
|
// 1) Remove context_management before forwarding to Codex upstream.
|
||||||
|
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||||
|
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||||
// accept "system" role in the input array.
|
// accept "system" role in the input array.
|
||||||
|
|||||||
@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
|
|||||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"context_management": [
|
||||||
|
{
|
||||||
|
"type": "compaction",
|
||||||
|
"compact_threshold": 12000
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "context_management").Exists() {
|
||||||
|
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"truncation": "disabled",
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi
|
|||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"type": "github-copilot",
|
"type": "github-copilot",
|
||||||
"username": authBundle.Username,
|
"username": authBundle.Username,
|
||||||
|
"email": authBundle.Email,
|
||||||
|
"name": authBundle.Name,
|
||||||
"access_token": authBundle.TokenData.AccessToken,
|
"access_token": authBundle.TokenData.AccessToken,
|
||||||
"token_type": authBundle.TokenData.TokenType,
|
"token_type": authBundle.TokenData.TokenType,
|
||||||
"scope": authBundle.TokenData.Scope,
|
"scope": authBundle.TokenData.Scope,
|
||||||
@@ -98,13 +100,18 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi
|
|||||||
|
|
||||||
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
||||||
|
|
||||||
|
label := authBundle.Email
|
||||||
|
if label == "" {
|
||||||
|
label = authBundle.Username
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
||||||
|
|
||||||
return &coreauth.Auth{
|
return &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: a.Provider(),
|
Provider: a.Provider(),
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Label: authBundle.Username,
|
Label: label,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: metadata,
|
Metadata: metadata,
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -1828,9 +1828,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
|||||||
// every few seconds and triggers refresh operations when required.
|
// every few seconds and triggers refresh operations when required.
|
||||||
// Only one loop is kept alive; starting a new one cancels the previous run.
|
// Only one loop is kept alive; starting a new one cancels the previous run.
|
||||||
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
||||||
if interval <= 0 || interval > refreshCheckInterval {
|
if interval <= 0 {
|
||||||
interval = refreshCheckInterval
|
|
||||||
} else {
|
|
||||||
interval = refreshCheckInterval
|
interval = refreshCheckInterval
|
||||||
}
|
}
|
||||||
if m.refreshCancel != nil {
|
if m.refreshCancel != nil {
|
||||||
|
|||||||
@@ -963,6 +963,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
}
|
}
|
||||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||||
|
if provider == "antigravity" {
|
||||||
|
s.backfillAntigravityModels(a, models)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1107,6 +1110,56 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
|
|||||||
return cfg.OAuthExcludedModels[providerKey]
|
return cfg.OAuthExcludedModels[providerKey]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
|
||||||
|
if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceID := ""
|
||||||
|
if source != nil {
|
||||||
|
sourceID = strings.TrimSpace(source.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
for _, candidate := range s.coreManager.List() {
|
||||||
|
if candidate == nil || candidate.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidateID := strings.TrimSpace(candidate.ID)
|
||||||
|
if candidateID == "" || candidateID == sourceID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(reg.GetModelsForClient(candidateID)) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
|
||||||
|
if authKind == "" {
|
||||||
|
if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||||
|
authKind = "apikey"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
excluded := s.oauthExcludedModels("antigravity", authKind)
|
||||||
|
if candidate.Attributes != nil {
|
||||||
|
if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
|
||||||
|
excluded = strings.Split(val, ",")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
models := applyExcludedModels(primaryModels, excluded)
|
||||||
|
models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
|
||||||
|
if len(models) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||||
|
log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||||
if len(models) == 0 || len(excluded) == 0 {
|
if len(models) == 0 || len(excluded) == 0 {
|
||||||
return models
|
return models
|
||||||
|
|||||||
135
sdk/cliproxy/service_antigravity_backfill_test.go
Normal file
135
sdk/cliproxy/service_antigravity_backfill_test.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package cliproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
|
||||||
|
source := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-source",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
target := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-target",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||||
|
t.Fatalf("register source auth: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||||
|
t.Fatalf("register target auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
primary := []*ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||||
|
|
||||||
|
service.backfillAntigravityModels(source, primary)
|
||||||
|
|
||||||
|
got := reg.GetModelsForClient(target.ID)
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make(map[string]struct{}, len(got))
|
||||||
|
for _, model := range got {
|
||||||
|
if model == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
|
||||||
|
}
|
||||||
|
if _, ok := ids["claude-sonnet-4-5"]; !ok {
|
||||||
|
t.Fatal("expected backfilled model claude-sonnet-4-5")
|
||||||
|
}
|
||||||
|
if _, ok := ids["gemini-2.5-pro"]; !ok {
|
||||||
|
t.Fatal("expected backfilled model gemini-2.5-pro")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
|
||||||
|
source := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-source-excluded",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
target := &coreauth.Auth{
|
||||||
|
ID: "ag-backfill-target-excluded",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
"excluded_models": "gemini-2.5-pro",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||||
|
t.Fatalf("register source auth: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||||
|
t.Fatalf("register target auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(source.ID)
|
||||||
|
reg.UnregisterClient(target.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
primary := []*ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||||
|
|
||||||
|
service.backfillAntigravityModels(source, primary)
|
||||||
|
|
||||||
|
got := reg.GetModelsForClient(target.ID)
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("expected 1 model after exclusion, got %d", len(got))
|
||||||
|
}
|
||||||
|
if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
|
||||||
|
t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user