diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 0c41137f..4422428b 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -13,7 +13,6 @@ import ( "io" "net" "net/http" - "net/url" "os" "path/filepath" "sort" @@ -23,6 +22,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" @@ -236,14 +236,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) { log.Infof("callback forwarder on port %d stopped", port) } -func sanitizeAntigravityFileName(email string) string { - if strings.TrimSpace(email) == "" { - return "antigravity.json" - } - replacer := strings.NewReplacer("@", "_", ".", "_") - return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email)) -} - func (h *Handler) managementCallbackURL(path string) (string, error) { if h == nil || h.cfg == nil || h.cfg.Port <= 0 { return "", fmt.Errorf("server port is not configured") @@ -985,67 +977,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { rawCode := resultMap["code"] code := strings.Split(rawCode, "#")[0] - // Exchange code for tokens (replicate logic using updated redirect_uri) - // Extract client_id from the modified auth URL - clientID := "" - if u2, errP := url.Parse(authURL); errP == nil { - clientID = u2.Query().Get("client_id") - } - // Build request - bodyMap := map[string]any{ - "code": code, - "state": state, - "grant_type": "authorization_code", - "client_id": clientID, - "redirect_uri": "http://localhost:54545/callback", - "code_verifier": pkceCodes.CodeVerifier, - } - bodyJSON, _ := json.Marshal(bodyMap) - - httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) - req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON))) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - resp, errDo := httpClient.Do(req) - if errDo != nil { - authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) + // Exchange code for tokens using internal auth service + bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes) + if errExchange != nil { + authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") return } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("failed to close response body: %v", errClose) - } - }() - respBody, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) - return - } - var tResp struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - Account struct { - EmailAddress string `json:"email_address"` - } `json:"account"` - } - if errU := json.Unmarshal(respBody, &tResp); errU != nil { - log.Errorf("failed to parse token response: %v", errU) - SetOAuthSessionError(state, "Failed to parse token response") - return - } - bundle := &claude.ClaudeAuthBundle{ - TokenData: claude.ClaudeTokenData{ - AccessToken: tResp.AccessToken, - RefreshToken: tResp.RefreshToken, - Email: tResp.Account.EmailAddress, - Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339), - }, - LastRefresh: time.Now().Format(time.RFC3339), - } // Create token storage tokenStorage := anthropicAuth.CreateTokenStorage(bundle) @@ -1085,17 +1024,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Println("Initializing Google authentication...") - // OAuth2 configuration (mirrors internal/auth/gemini) + // OAuth2 configuration using exported constants from internal/auth/gemini conf := &oauth2.Config{ - ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com", - ClientSecret: "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", - RedirectURL: "http://localhost:8085/oauth2callback", - Scopes: []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - }, - Endpoint: google.Endpoint, + ClientID: geminiAuth.ClientID, + ClientSecret: geminiAuth.ClientSecret, + RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort), + Scopes: geminiAuth.Scopes, + Endpoint: google.Endpoint, } // Build authorization URL and return it immediately @@ -1217,13 +1152,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } ifToken["token_uri"] = "https://oauth2.googleapis.com/token" - ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - ifToken["client_secret"] = "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" - ifToken["scopes"] = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - } + ifToken["client_id"] = geminiAuth.ClientID + ifToken["client_secret"] = geminiAuth.ClientSecret + ifToken["scopes"] = geminiAuth.Scopes ifToken["universe_domain"] = "googleapis.com" ts := geminiAuth.GeminiTokenStorage{ @@ -1410,73 +1341,25 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } log.Debug("Authorization code received, exchanging for tokens...") - // Extract client_id from authURL - clientID := "" - if u2, errP := url.Parse(authURL); errP == nil { - clientID = u2.Query().Get("client_id") - } - // Exchange code for tokens with redirect equal to mgmtRedirect - form := url.Values{ - "grant_type": {"authorization_code"}, - "client_id": {clientID}, - "code": {code}, - "redirect_uri": {"http://localhost:1455/auth/callback"}, - "code_verifier": {pkceCodes.CodeVerifier}, - } - httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) - req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - resp, errDo := httpClient.Do(req) - if errDo != nil { - authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) + // Exchange code for tokens using internal auth service + bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes) + if errExchange != nil { + authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange) SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } - defer func() { _ = resp.Body.Close() }() - respBody, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) - log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - return - } - var tokenResp struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - IDToken string `json:"id_token"` - ExpiresIn int `json:"expires_in"` - } - if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - SetOAuthSessionError(state, "Failed to parse token response") - log.Errorf("failed to parse token response: %v", errU) - return - } - claims, _ := codex.ParseJWTToken(tokenResp.IDToken) - email := "" - accountID := "" + + // Extract additional info for filename generation + claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken) planType := "" - if claims != nil { - email = claims.GetUserEmail() - accountID = claims.GetAccountID() - planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) - } hashAccountID := "" - if accountID != "" { - digest := sha256.Sum256([]byte(accountID)) - hashAccountID = hex.EncodeToString(digest[:])[:8] - } - // Build bundle compatible with existing storage - bundle := &codex.CodexAuthBundle{ - TokenData: codex.CodexTokenData{ - IDToken: tokenResp.IDToken, - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - AccountID: accountID, - Email: email, - Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339), - }, - LastRefresh: time.Now().Format(time.RFC3339), + if claims != nil { + planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) + if accountID := claims.GetAccountID(); accountID != "" { + digest := sha256.Sum256([]byte(accountID)) + hashAccountID = hex.EncodeToString(digest[:])[:8] + } } // Create token storage and persist @@ -1511,23 +1394,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } func (h *Handler) RequestAntigravityToken(c *gin.Context) { - const ( - antigravityCallbackPort = 51121 - antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - antigravityClientSecret = "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" - ) - var antigravityScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - "https://www.googleapis.com/auth/cclog", - "https://www.googleapis.com/auth/experimentsandconfigs", - } - ctx := context.Background() fmt.Println("Initializing Antigravity authentication...") + authSvc := antigravity.NewAntigravityAuth(h.cfg, nil) + state, errState := misc.GenerateRandomState() if errState != nil { log.Errorf("Failed to generate state parameter: %v", errState) @@ -1535,17 +1407,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { return } - redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort) - - params := url.Values{} - params.Set("access_type", "offline") - params.Set("client_id", antigravityClientID) - params.Set("prompt", "consent") - params.Set("redirect_uri", redirectURI) - params.Set("response_type", "code") - params.Set("scope", strings.Join(antigravityScopes, " ")) - params.Set("state", state) - authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() + redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort) + authURL := authSvc.BuildAuthURL(state, redirectURI) RegisterOAuthSession(state, "antigravity") @@ -1559,7 +1422,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { return } var errStart error - if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { + if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start antigravity callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1568,7 +1431,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) + defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder) } waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) @@ -1608,93 +1471,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { time.Sleep(500 * time.Millisecond) } - httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) - form := url.Values{} - form.Set("code", authCode) - form.Set("client_id", antigravityClientID) - form.Set("client_secret", antigravityClientSecret) - form.Set("redirect_uri", redirectURI) - form.Set("grant_type", "authorization_code") - - req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) - if errNewRequest != nil { - log.Errorf("Failed to build token request: %v", errNewRequest) - SetOAuthSessionError(state, "Failed to build token request") - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, errDo := httpClient.Do(req) - if errDo != nil { - log.Errorf("Failed to execute token request: %v", errDo) + tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI) + if errToken != nil { + log.Errorf("Failed to exchange token: %v", errToken) SetOAuthSessionError(state, "Failed to exchange token") return } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity token exchange close error: %v", errClose) - } - }() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - bodyBytes, _ := io.ReadAll(resp.Body) - log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) + accessToken := strings.TrimSpace(tokenResp.AccessToken) + if accessToken == "" { + log.Error("antigravity: token exchange returned empty access token") + SetOAuthSessionError(state, "Failed to exchange token") return } - var tokenResp struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - TokenType string `json:"token_type"` - } - if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { - log.Errorf("Failed to parse token response: %v", errDecode) - SetOAuthSessionError(state, "Failed to parse token response") + email, errInfo := authSvc.FetchUserInfo(ctx, accessToken) + if errInfo != nil { + log.Errorf("Failed to fetch user info: %v", errInfo) + SetOAuthSessionError(state, "Failed to fetch user info") return } - - email := "" - if strings.TrimSpace(tokenResp.AccessToken) != "" { - infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) - if errInfoReq != nil { - log.Errorf("Failed to build user info request: %v", errInfoReq) - SetOAuthSessionError(state, "Failed to build user info request") - return - } - infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) - - infoResp, errInfo := httpClient.Do(infoReq) - if errInfo != nil { - log.Errorf("Failed to execute user info request: %v", errInfo) - SetOAuthSessionError(state, "Failed to execute user info request") - return - } - defer func() { - if errClose := infoResp.Body.Close(); errClose != nil { - log.Errorf("antigravity user info close error: %v", errClose) - } - }() - - if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices { - var infoPayload struct { - Email string `json:"email"` - } - if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil { - email = strings.TrimSpace(infoPayload.Email) - } - } else { - bodyBytes, _ := io.ReadAll(infoResp.Body) - log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) - return - } + email = strings.TrimSpace(email) + if email == "" { + log.Error("antigravity: user info returned empty email") + SetOAuthSessionError(state, "Failed to fetch user info") + return } projectID := "" - if strings.TrimSpace(tokenResp.AccessToken) != "" { - fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient) + if accessToken != "" { + fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken) if errProject != nil { log.Warnf("antigravity: failed to fetch project ID: %v", errProject) } else { @@ -1719,7 +1525,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { metadata["project_id"] = projectID } - fileName := sanitizeAntigravityFileName(email) + fileName := antigravity.CredentialFileName(email) label := strings.TrimSpace(email) if label == "" { label = "antigravity" diff --git a/internal/auth/antigravity/auth.go b/internal/auth/antigravity/auth.go new file mode 100644 index 00000000..449f413f --- /dev/null +++ b/internal/auth/antigravity/auth.go @@ -0,0 +1,344 @@ +// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider. +package antigravity + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// TokenResponse represents OAuth token response from Google +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` +} + +// userInfo represents Google user profile +type userInfo struct { + Email string `json:"email"` +} + +// AntigravityAuth handles Antigravity OAuth authentication +type AntigravityAuth struct { + httpClient *http.Client +} + +// NewAntigravityAuth creates a new Antigravity auth service. +func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth { + if httpClient != nil { + return &AntigravityAuth{httpClient: httpClient} + } + if cfg == nil { + cfg = &config.Config{} + } + return &AntigravityAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}), + } +} + +// BuildAuthURL generates the OAuth authorization URL. +func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string { + if strings.TrimSpace(redirectURI) == "" { + redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort) + } + params := url.Values{} + params.Set("access_type", "offline") + params.Set("client_id", ClientID) + params.Set("prompt", "consent") + params.Set("redirect_uri", redirectURI) + params.Set("response_type", "code") + params.Set("scope", strings.Join(Scopes, " ")) + params.Set("state", state) + return AuthEndpoint + "?" + params.Encode() +} + +// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens +func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) { + data := url.Values{} + data.Set("code", code) + data.Set("client_id", ClientID) + data.Set("client_secret", ClientSecret) + data.Set("redirect_uri", redirectURI) + data.Set("grant_type", "authorization_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("antigravity token exchange: create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity token exchange: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + if errRead != nil { + return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead) + } + body := strings.TrimSpace(string(bodyBytes)) + if body == "" { + return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode) + } + return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body) + } + + var token TokenResponse + if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil { + return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode) + } + return &token, nil +} + +// FetchUserInfo retrieves user email from Google +func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) { + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return "", fmt.Errorf("antigravity userinfo: missing access token") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil) + if err != nil { + return "", fmt.Errorf("antigravity userinfo: create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity userinfo: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + if errRead != nil { + return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead) + } + body := strings.TrimSpace(string(bodyBytes)) + if body == "" { + return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode) + } + return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body) + } + var info userInfo + if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil { + return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode) + } + email := strings.TrimSpace(info.Email) + if email == "" { + return "", fmt.Errorf("antigravity userinfo: response missing email") + } + return email, nil +} + +// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist +func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) { + loadReqBody := map[string]any{ + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(loadReqBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", APIUserAgent) + req.Header.Set("X-Goog-Api-Client", APIClient) + req.Header.Set("Client-Metadata", ClientMetadata) + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return "", fmt.Errorf("execute request: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose) + } + }() + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) + } + + var loadResp map[string]any + if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + // Extract projectID from response + projectID := "" + if id, ok := loadResp["cloudaicompanionProject"].(string); ok { + projectID = strings.TrimSpace(id) + } + if projectID == "" { + if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok { + if id, okID := projectMap["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + } + } + + if projectID == "" { + tierID := "legacy-tier" + if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { + for _, rawTier := range tiers { + tier, okTier := rawTier.(map[string]any) + if !okTier { + continue + } + if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { + if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { + tierID = strings.TrimSpace(id) + break + } + } + } + } + + projectID, err = o.OnboardUser(ctx, accessToken, tierID) + if err != nil { + return "", err + } + return projectID, nil + } + + return projectID, nil +} + +// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion +func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) { + log.Infof("Antigravity: onboarding user with tier: %s", tierID) + requestBody := map[string]any{ + "tierId": tierID, + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(requestBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Debugf("Polling attempt %d/%d", attempt, maxAttempts) + + reqCtx := ctx + var cancel context.CancelFunc + if reqCtx == nil { + reqCtx = context.Background() + } + reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) + + endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion) + req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if errRequest != nil { + cancel() + return "", fmt.Errorf("create request: %w", errRequest) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", APIUserAgent) + req.Header.Set("X-Goog-Api-Client", APIClient) + req.Header.Set("Client-Metadata", ClientMetadata) + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + cancel() + return "", fmt.Errorf("execute request: %w", errDo) + } + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("close body error: %v", errClose) + } + cancel() + + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode == http.StatusOK { + var data map[string]any + if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + if done, okDone := data["done"].(bool); okDone && done { + projectID := "" + if responseData, okResp := data["response"].(map[string]any); okResp { + switch projectValue := responseData["cloudaicompanionProject"].(type) { + case map[string]any: + if id, okID := projectValue["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + case string: + projectID = strings.TrimSpace(projectValue) + } + } + + if projectID != "" { + log.Infof("Successfully fetched project_id: %s", projectID) + return projectID, nil + } + + return "", fmt.Errorf("no project_id in response") + } + + time.Sleep(2 * time.Second) + continue + } + + responsePreview := strings.TrimSpace(string(bodyBytes)) + if len(responsePreview) > 500 { + responsePreview = responsePreview[:500] + } + + responseErr := responsePreview + if len(responseErr) > 200 { + responseErr = responseErr[:200] + } + return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) + } + + return "", nil +} diff --git a/internal/auth/antigravity/constants.go b/internal/auth/antigravity/constants.go new file mode 100644 index 00000000..680c8e3c --- /dev/null +++ b/internal/auth/antigravity/constants.go @@ -0,0 +1,34 @@ +// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider. +package antigravity + +// OAuth client credentials and configuration +const ( + ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + ClientSecret = "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" + CallbackPort = 51121 +) + +// Scopes defines the OAuth scopes required for Antigravity authentication +var Scopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", +} + +// OAuth2 endpoints for Google authentication +const ( + TokenEndpoint = "https://oauth2.googleapis.com/token" + AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth" + UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json" +) + +// Antigravity API configuration +const ( + APIEndpoint = "https://cloudcode-pa.googleapis.com" + APIVersion = "v1internal" + APIUserAgent = "google-api-nodejs-client/9.15.1" + APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" + ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` +) diff --git a/internal/auth/antigravity/filename.go b/internal/auth/antigravity/filename.go new file mode 100644 index 00000000..03ad3e2f --- /dev/null +++ b/internal/auth/antigravity/filename.go @@ -0,0 +1,16 @@ +package antigravity + +import ( + "fmt" + "strings" +) + +// CredentialFileName returns the filename used to persist Antigravity credentials. +// It uses the email as a suffix to disambiguate accounts. +func CredentialFileName(email string) string { + email = strings.TrimSpace(email) + if email == "" { + return "antigravity.json" + } + return fmt.Sprintf("antigravity-%s.json", email) +} diff --git a/internal/auth/claude/anthropic_auth.go b/internal/auth/claude/anthropic_auth.go index 07bd5b42..54edce3b 100644 --- a/internal/auth/claude/anthropic_auth.go +++ b/internal/auth/claude/anthropic_auth.go @@ -18,11 +18,12 @@ import ( log "github.com/sirupsen/logrus" ) +// OAuth configuration constants for Claude/Anthropic const ( - anthropicAuthURL = "https://claude.ai/oauth/authorize" - anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" - anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" - redirectURI = "http://localhost:54545/callback" + AuthURL = "https://claude.ai/oauth/authorize" + TokenURL = "https://console.anthropic.com/v1/oauth/token" + ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + RedirectURI = "http://localhost:54545/callback" ) // tokenResponse represents the response structure from Anthropic's OAuth token endpoint. @@ -82,16 +83,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string params := url.Values{ "code": {"true"}, - "client_id": {anthropicClientID}, + "client_id": {ClientID}, "response_type": {"code"}, - "redirect_uri": {redirectURI}, + "redirect_uri": {RedirectURI}, "scope": {"org:create_api_key user:profile user:inference"}, "code_challenge": {pkceCodes.CodeChallenge}, "code_challenge_method": {"S256"}, "state": {state}, } - authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode()) + authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode()) return authURL, state, nil } @@ -137,8 +138,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri "code": newCode, "state": state, "grant_type": "authorization_code", - "client_id": anthropicClientID, - "redirect_uri": redirectURI, + "client_id": ClientID, + "redirect_uri": RedirectURI, "code_verifier": pkceCodes.CodeVerifier, } @@ -154,7 +155,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri // log.Debugf("Token exchange request: %s", string(jsonBody)) - req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody))) + req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody))) if err != nil { return nil, fmt.Errorf("failed to create token request: %w", err) } @@ -221,7 +222,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C } reqBody := map[string]interface{}{ - "client_id": anthropicClientID, + "client_id": ClientID, "grant_type": "refresh_token", "refresh_token": refreshToken, } @@ -231,7 +232,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C return nil, fmt.Errorf("failed to marshal request body: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody))) + req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody))) if err != nil { return nil, fmt.Errorf("failed to create refresh request: %w", err) } diff --git a/internal/auth/codex/openai_auth.go b/internal/auth/codex/openai_auth.go index c0299c3d..89deeadb 100644 --- a/internal/auth/codex/openai_auth.go +++ b/internal/auth/codex/openai_auth.go @@ -19,11 +19,12 @@ import ( log "github.com/sirupsen/logrus" ) +// OAuth configuration constants for OpenAI Codex const ( - openaiAuthURL = "https://auth.openai.com/oauth/authorize" - openaiTokenURL = "https://auth.openai.com/oauth/token" - openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - redirectURI = "http://localhost:1455/auth/callback" + AuthURL = "https://auth.openai.com/oauth/authorize" + TokenURL = "https://auth.openai.com/oauth/token" + ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + RedirectURI = "http://localhost:1455/auth/callback" ) // CodexAuth handles the OpenAI OAuth2 authentication flow. @@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, } params := url.Values{ - "client_id": {openaiClientID}, + "client_id": {ClientID}, "response_type": {"code"}, - "redirect_uri": {redirectURI}, + "redirect_uri": {RedirectURI}, "scope": {"openid email profile offline_access"}, "state": {state}, "code_challenge": {pkceCodes.CodeChallenge}, @@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, "codex_cli_simplified_flow": {"true"}, } - authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode()) + authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode()) return authURL, nil } @@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce // Prepare token exchange request data := url.Values{ "grant_type": {"authorization_code"}, - "client_id": {openaiClientID}, + "client_id": {ClientID}, "code": {code}, - "redirect_uri": {redirectURI}, + "redirect_uri": {RedirectURI}, "code_verifier": {pkceCodes.CodeVerifier}, } - req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode())) + req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, fmt.Errorf("failed to create token request: %w", err) } @@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co } data := url.Values{ - "client_id": {openaiClientID}, + "client_id": {ClientID}, "grant_type": {"refresh_token"}, "refresh_token": {refreshToken}, "scope": {"openid profile email"}, } - req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode())) + req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, fmt.Errorf("failed to create refresh request: %w", err) } diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index 708ac809..6406a0e1 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -28,19 +28,19 @@ import ( "golang.org/x/oauth2/google" ) +// OAuth configuration constants for Gemini const ( - geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - geminiOauthClientSecret = "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" - geminiDefaultCallbackPort = 8085 + ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + ClientSecret = "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" + DefaultCallbackPort = 8085 ) -var ( - geminiOauthScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - } -) +// OAuth scopes for Gemini authentication +var Scopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", +} // GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow. // It encapsulates the logic for obtaining, storing, and refreshing authentication tokens @@ -74,7 +74,7 @@ func NewGeminiAuth() *GeminiAuth { // - *http.Client: An HTTP client configured with authentication // - error: An error if the client configuration fails, nil otherwise func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { - callbackPort := geminiDefaultCallbackPort + callbackPort := DefaultCallbackPort if opts != nil && opts.CallbackPort > 0 { callbackPort = opts.CallbackPort } @@ -112,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken // Configure the OAuth2 client. conf := &oauth2.Config{ - ClientID: geminiOauthClientID, - ClientSecret: geminiOauthClientSecret, + ClientID: ClientID, + ClientSecret: ClientSecret, RedirectURL: callbackURL, // This will be used by the local server. - Scopes: geminiOauthScopes, + Scopes: Scopes, Endpoint: google.Endpoint, } @@ -198,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf } ifToken["token_uri"] = "https://oauth2.googleapis.com/token" - ifToken["client_id"] = geminiOauthClientID - ifToken["client_secret"] = geminiOauthClientSecret - ifToken["scopes"] = geminiOauthScopes + ifToken["client_id"] = ClientID + ifToken["client_secret"] = ClientSecret + ifToken["scopes"] = Scopes ifToken["universe_domain"] = "googleapis.com" ts := GeminiTokenStorage{ @@ -226,7 +226,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - error: An error if the token acquisition fails, nil otherwise func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { - callbackPort := geminiDefaultCallbackPort + callbackPort := DefaultCallbackPort if opts != nil && opts.CallbackPort > 0 { callbackPort = opts.CallbackPort } diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index f9b6331c..3145023e 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -1042,10 +1042,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) "owned_by": model.OwnedBy, } if model.Created > 0 { - result["created"] = model.Created + result["created_at"] = model.Created } if model.Type != "" { - result["type"] = model.Type + result["type"] = "model" } if model.DisplayName != "" { result["display_name"] = model.DisplayName diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 110a1445..1ceb0f73 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -997,7 +997,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c now := time.Now().Unix() modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) - for originalName := range result.Map() { + for originalName, modelData := range result.Map() { modelID := strings.TrimSpace(originalName) if modelID == "" { continue @@ -1007,12 +1007,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c continue } modelCfg := modelConfig[modelID] - modelName := modelID + + // Extract displayName from upstream response, fallback to modelID + displayName := modelData.Get("displayName").String() + if displayName == "" { + displayName = modelID + } + modelInfo := ®istry.ModelInfo{ ID: modelID, - Name: modelName, - Description: modelID, - DisplayName: modelID, + Name: modelID, + Description: displayName, + DisplayName: displayName, Version: modelID, Object: "model", Created: now, diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7a9f1005..9c291328 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -733,6 +733,11 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte { if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { tools.ForEach(func(index, tool gjson.Result) bool { + // Skip built-in tools (web_search, code_execution, etc.) which have + // a "type" field and require their name to remain unchanged. + if tool.Get("type").Exists() && tool.Get("type").String() != "" { + return true + } name := tool.Get("name").String() if name == "" || strings.HasPrefix(name, prefix) { return true diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 05f5b60c..36fb7ad4 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) { } } +func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) { + input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`) + out := applyClaudeToolPrefix(input, "proxy_") + + if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" { + t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search") + } + if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" { + t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool") + } +} + func TestStripClaudeToolPrefixFromResponse(t *testing.T) { input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`) out := stripClaudeToolPrefixFromResponse(input, "proxy_") diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 51d4a02a..f2cb04d6 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -305,12 +305,12 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } - // tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough + // tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough tools := gjson.GetBytes(rawJSON, "tools") if tools.IsArray() && len(tools.Array()) > 0 { - toolNode := []byte(`{}`) - hasTool := false + functionToolNode := []byte(`{}`) hasFunction := false + googleSearchNodes := make([][]byte, 0) for _, t := range tools.Array() { if t.Get("type").String() == "function" { fn := t.Get("function") @@ -349,31 +349,37 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } fnRaw, _ = sjson.Delete(fnRaw, "strict") if !hasFunction { - toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]")) + functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) } - tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw)) + tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw)) if errSet != nil { log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) continue } - toolNode = tmp + functionToolNode = tmp hasFunction = true - hasTool = true } } if gs := t.Get("google_search"); gs.Exists() { + googleToolNode := []byte(`{}`) var errSet error - toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw)) + googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw)) if errSet != nil { log.Warnf("Failed to set googleSearch tool: %v", errSet) continue } - hasTool = true + googleSearchNodes = append(googleSearchNodes, googleToolNode) } } - if hasTool { - out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]")) - out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode) + if hasFunction || len(googleSearchNodes) > 0 { + toolsNode := []byte("[]") + if hasFunction { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode) + } + for _, googleNode := range googleSearchNodes { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode) + } + out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode) } } diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index 85669689..6351fa58 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -283,12 +283,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo } } - // tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough + // tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough tools := gjson.GetBytes(rawJSON, "tools") if tools.IsArray() && len(tools.Array()) > 0 { - toolNode := []byte(`{}`) - hasTool := false + functionToolNode := []byte(`{}`) hasFunction := false + googleSearchNodes := make([][]byte, 0) for _, t := range tools.Array() { if t.Get("type").String() == "function" { fn := t.Get("function") @@ -327,31 +327,37 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo } fnRaw, _ = sjson.Delete(fnRaw, "strict") if !hasFunction { - toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]")) + functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) } - tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw)) + tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw)) if errSet != nil { log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) continue } - toolNode = tmp + functionToolNode = tmp hasFunction = true - hasTool = true } } if gs := t.Get("google_search"); gs.Exists() { + googleToolNode := []byte(`{}`) var errSet error - toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw)) + googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw)) if errSet != nil { log.Warnf("Failed to set googleSearch tool: %v", errSet) continue } - hasTool = true + googleSearchNodes = append(googleSearchNodes, googleToolNode) } } - if hasTool { - out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]")) - out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode) + if hasFunction || len(googleSearchNodes) > 0 { + toolsNode := []byte("[]") + if hasFunction { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode) + } + for _, googleNode := range googleSearchNodes { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode) + } + out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode) } } diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index ba8b47e3..0a35cfd0 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -289,12 +289,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) } } - // tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough + // tools -> tools[].functionDeclarations + tools[].googleSearch passthrough tools := gjson.GetBytes(rawJSON, "tools") if tools.IsArray() && len(tools.Array()) > 0 { - toolNode := []byte(`{}`) - hasTool := false + functionToolNode := []byte(`{}`) hasFunction := false + googleSearchNodes := make([][]byte, 0) for _, t := range tools.Array() { if t.Get("type").String() == "function" { fn := t.Get("function") @@ -333,31 +333,37 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) } fnRaw, _ = sjson.Delete(fnRaw, "strict") if !hasFunction { - toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]")) + functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) } - tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw)) + tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw)) if errSet != nil { log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) continue } - toolNode = tmp + functionToolNode = tmp hasFunction = true - hasTool = true } } if gs := t.Get("google_search"); gs.Exists() { + googleToolNode := []byte(`{}`) var errSet error - toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw)) + googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw)) if errSet != nil { log.Warnf("Failed to set googleSearch tool: %v", errSet) continue } - hasTool = true + googleSearchNodes = append(googleSearchNodes, googleToolNode) } } - if hasTool { - out, _ = sjson.SetRawBytes(out, "tools", []byte("[]")) - out, _ = sjson.SetRawBytes(out, "tools.0", toolNode) + if hasFunction || len(googleSearchNodes) > 0 { + toolsNode := []byte("[]") + if hasFunction { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode) + } + for _, googleNode := range googleSearchNodes { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode) + } + out, _ = sjson.SetRawBytes(out, "tools", toolsNode) } } diff --git a/internal/translator/openai/claude/openai_claude_request_test.go b/internal/translator/openai/claude/openai_claude_request_test.go index 3a577957..d08de1b2 100644 --- a/internal/translator/openai/claude/openai_claude_request_test.go +++ b/internal/translator/openai/claude/openai_claude_request_test.go @@ -181,11 +181,11 @@ func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) { result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false) resultJSON := gjson.ParseBytes(result) - // Find the relevant message (skip system message at index 0) + // Find the relevant message messages := resultJSON.Get("messages").Array() - if len(messages) < 2 { + if len(messages) < 1 { if tt.wantHasReasoningContent || tt.wantHasContent { - t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages)) + t.Fatalf("Expected at least 1 message, got %d", len(messages)) } return } @@ -272,15 +272,15 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) messages := resultJSON.Get("messages").Array() - // Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages - if len(messages) != 4 { - t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw) + // Should have: user + assistant (thinking-only) + user = 3 messages + if len(messages) != 3 { + t.Fatalf("Expected 3 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw) } - // Check the assistant message (index 2) has reasoning_content - assistantMsg := messages[2] + // Check the assistant message (index 1) has reasoning_content + assistantMsg := messages[1] if assistantMsg.Get("role").String() != "assistant" { - t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String()) + t.Errorf("Expected message[1] to be assistant, got %s", assistantMsg.Get("role").String()) } if !assistantMsg.Get("reasoning_content").Exists() { @@ -292,6 +292,104 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) } } +func TestConvertClaudeRequestToOpenAI_SystemMessageScenarios(t *testing.T) { + tests := []struct { + name string + inputJSON string + wantHasSys bool + wantSysText string + }{ + { + name: "No system field", + inputJSON: `{ + "model": "claude-3-opus", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasSys: false, + }, + { + name: "Empty string system field", + inputJSON: `{ + "model": "claude-3-opus", + "system": "", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasSys: false, + }, + { + name: "String system field", + inputJSON: `{ + "model": "claude-3-opus", + "system": "Be helpful", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasSys: true, + wantSysText: "Be helpful", + }, + { + name: "Array system field with text", + inputJSON: `{ + "model": "claude-3-opus", + "system": [{"type": "text", "text": "Array system"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasSys: true, + wantSysText: "Array system", + }, + { + name: "Array system field with multiple text blocks", + inputJSON: `{ + "model": "claude-3-opus", + "system": [ + {"type": "text", "text": "Block 1"}, + {"type": "text", "text": "Block 2"} + ], + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasSys: true, + wantSysText: "Block 2", // We will update the test logic to check all blocks or specifically the second one + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false) + resultJSON := gjson.ParseBytes(result) + messages := resultJSON.Get("messages").Array() + + hasSys := false + var sysMsg gjson.Result + if len(messages) > 0 && messages[0].Get("role").String() == "system" { + hasSys = true + sysMsg = messages[0] + } + + if hasSys != tt.wantHasSys { + t.Errorf("got hasSystem = %v, want %v", hasSys, tt.wantHasSys) + } + + if tt.wantHasSys { + // Check content - it could be string or array in OpenAI + content := sysMsg.Get("content") + var gotText string + if content.IsArray() { + arr := content.Array() + if len(arr) > 0 { + // Get the last element's text for validation + gotText = arr[len(arr)-1].Get("text").String() + } + } else { + gotText = content.String() + } + + if tt.wantSysText != "" && gotText != tt.wantSysText { + t.Errorf("got system text = %q, want %q", gotText, tt.wantSysText) + } + } + }) + } +} + func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) { inputJSON := `{ "model": "claude-3-opus", @@ -318,39 +416,35 @@ func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) { messages := resultJSON.Get("messages").Array() // OpenAI requires: tool messages MUST immediately follow assistant(tool_calls). - // Correct order: system + assistant(tool_calls) + tool(result) + user(before+after) - if len(messages) != 4 { - t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + // Correct order: assistant(tool_calls) + tool(result) + user(before+after) + if len(messages) != 3 { + t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) } - if messages[0].Get("role").String() != "system" { - t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String()) - } - - if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() { - t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw) + if messages[0].Get("role").String() != "assistant" || !messages[0].Get("tool_calls").Exists() { + t.Fatalf("Expected messages[0] to be assistant tool_calls, got %s: %s", messages[0].Get("role").String(), messages[0].Raw) } // tool message MUST immediately follow assistant(tool_calls) per OpenAI spec - if messages[2].Get("role").String() != "tool" { - t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String()) + if messages[1].Get("role").String() != "tool" { + t.Fatalf("Expected messages[1] to be tool (must follow tool_calls), got %s", messages[1].Get("role").String()) } - if got := messages[2].Get("tool_call_id").String(); got != "call_1" { + if got := messages[1].Get("tool_call_id").String(); got != "call_1" { t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got) } - if got := messages[2].Get("content").String(); got != "tool ok" { + if got := messages[1].Get("content").String(); got != "tool ok" { t.Fatalf("Expected tool content %q, got %q", "tool ok", got) } // User message comes after tool message - if messages[3].Get("role").String() != "user" { - t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String()) + if messages[2].Get("role").String() != "user" { + t.Fatalf("Expected messages[2] to be user, got %s", messages[2].Get("role").String()) } // User message should contain both "before" and "after" text - if got := messages[3].Get("content.0.text").String(); got != "before" { + if got := messages[2].Get("content.0.text").String(); got != "before" { t.Fatalf("Expected user text[0] %q, got %q", "before", got) } - if got := messages[3].Get("content.1.text").String(); got != "after" { + if got := messages[2].Get("content.1.text").String(); got != "after" { t.Fatalf("Expected user text[1] %q, got %q", "after", got) } } @@ -378,16 +472,16 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) { resultJSON := gjson.ParseBytes(result) messages := resultJSON.Get("messages").Array() - // system + assistant(tool_calls) + tool(result) - if len(messages) != 3 { - t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + // assistant(tool_calls) + tool(result) + if len(messages) != 2 { + t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) } - if messages[2].Get("role").String() != "tool" { - t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String()) + if messages[1].Get("role").String() != "tool" { + t.Fatalf("Expected messages[1] to be tool, got %s", messages[1].Get("role").String()) } - toolContent := messages[2].Get("content").String() + toolContent := messages[1].Get("content").String() parsed := gjson.Parse(toolContent) if parsed.Get("foo").String() != "bar" { t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent) @@ -414,18 +508,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T messages := resultJSON.Get("messages").Array() // New behavior: content + tool_calls unified in single assistant message - // Expect: system + assistant(content[pre,post] + tool_calls) - if len(messages) != 2 { - t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + // Expect: assistant(content[pre,post] + tool_calls) + if len(messages) != 1 { + t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) } - if messages[0].Get("role").String() != "system" { - t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String()) - } - - assistantMsg := messages[1] + assistantMsg := messages[0] if assistantMsg.Get("role").String() != "assistant" { - t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String()) + t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String()) } // Should have both content and tool_calls in same message @@ -470,14 +560,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t messages := resultJSON.Get("messages").Array() // New behavior: all content, thinking, and tool_calls unified in single assistant message - // Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2]) - if len(messages) != 2 { - t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + // Expect: assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2]) + if len(messages) != 1 { + t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) } - assistantMsg := messages[1] + assistantMsg := messages[0] if assistantMsg.Get("role").String() != "assistant" { - t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String()) + t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String()) } // Should have content with both pre and post diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 30ff228d..22e10fa5 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -128,8 +128,23 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) { // Parameters: // - c: The Gin context for the request. func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) { + models := h.Models() + firstID := "" + lastID := "" + if len(models) > 0 { + if id, ok := models[0]["id"].(string); ok { + firstID = id + } + if id, ok := models[len(models)-1]["id"].(string); ok { + lastID = id + } + } + c.JSON(http.StatusOK, gin.H{ - "data": h.Models(), + "data": models, + "has_more": false, + "first_id": firstID, + "last_id": lastID, }) } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 27d8d1f5..71c485ad 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -60,8 +60,12 @@ func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) { if !strings.HasPrefix(name, "models/") { normalizedModel["name"] = "models/" + name } - normalizedModel["displayName"] = name - normalizedModel["description"] = name + if displayName, _ := normalizedModel["displayName"].(string); displayName == "" { + normalizedModel["displayName"] = name + } + if description, _ := normalizedModel["description"].(string); description == "" { + normalizedModel["description"] = name + } } if _, ok := normalizedModel["supportedGenerationMethods"]; !ok { normalizedModel["supportedGenerationMethods"] = defaultMethods diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 210da57f..ecca0a00 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -2,15 +2,13 @@ package auth import ( "context" - "encoding/json" "fmt" - "io" "net" "net/http" - "net/url" "strings" "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -19,20 +17,6 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - antigravityClientSecret = "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" - antigravityCallbackPort = 51121 -) - -var antigravityScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - "https://www.googleapis.com/auth/cclog", - "https://www.googleapis.com/auth/experimentsandconfigs", -} - // AntigravityAuthenticator implements OAuth login for the antigravity provider. type AntigravityAuthenticator struct{} @@ -60,12 +44,12 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o opts = &LoginOptions{} } - callbackPort := antigravityCallbackPort + callbackPort := antigravity.CallbackPort if opts.CallbackPort > 0 { callbackPort = opts.CallbackPort } - httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) + authSvc := antigravity.NewAntigravityAuth(cfg, nil) state, err := misc.GenerateRandomState() if err != nil { @@ -83,7 +67,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o }() redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port) - authURL := buildAntigravityAuthURL(redirectURI, state) + authURL := authSvc.BuildAuthURL(state, redirectURI) if !opts.NoBrowser { fmt.Println("Opening browser for antigravity authentication") @@ -164,22 +148,29 @@ waitForCallback: return nil, fmt.Errorf("antigravity: missing authorization code") } - tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient) + tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI) if errToken != nil { return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken) } - email := "" - if tokenResp.AccessToken != "" { - if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" { - email = strings.TrimSpace(info.Email) - } + accessToken := strings.TrimSpace(tokenResp.AccessToken) + if accessToken == "" { + return nil, fmt.Errorf("antigravity: token exchange returned empty access token") + } + + email, errInfo := authSvc.FetchUserInfo(ctx, accessToken) + if errInfo != nil { + return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo) + } + email = strings.TrimSpace(email) + if email == "" { + return nil, fmt.Errorf("antigravity: empty email returned from user info") } // Fetch project ID via loadCodeAssist (same approach as Gemini CLI) projectID := "" - if tokenResp.AccessToken != "" { - fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient) + if accessToken != "" { + fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken) if errProject != nil { log.Warnf("antigravity: failed to fetch project ID: %v", errProject) } else { @@ -204,7 +195,7 @@ waitForCallback: metadata["project_id"] = projectID } - fileName := sanitizeAntigravityFileName(email) + fileName := antigravity.CredentialFileName(email) label := email if label == "" { label = "antigravity" @@ -231,7 +222,7 @@ type callbackResult struct { func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) { if port <= 0 { - port = antigravityCallbackPort + port = antigravity.CallbackPort } addr := fmt.Sprintf(":%d", port) listener, err := net.Listen("tcp", addr) @@ -267,309 +258,9 @@ func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbac return srv, port, resultCh, nil } -type antigravityTokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - TokenType string `json:"token_type"` -} - -func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) { - data := url.Values{} - data.Set("code", code) - data.Set("client_id", antigravityClientID) - data.Set("client_secret", antigravityClientSecret) - data.Set("redirect_uri", redirectURI) - data.Set("grant_type", "authorization_code") - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode())) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return nil, errDo - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity token exchange: close body error: %v", errClose) - } - }() - - var token antigravityTokenResponse - if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil { - return nil, errDecode - } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode) - } - return &token, nil -} - -type antigravityUserInfo struct { - Email string `json:"email"` -} - -func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) { - if strings.TrimSpace(accessToken) == "" { - return &antigravityUserInfo{}, nil - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return nil, errDo - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity userinfo: close body error: %v", errClose) - } - }() - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return &antigravityUserInfo{}, nil - } - var info antigravityUserInfo - if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil { - return nil, errDecode - } - return &info, nil -} - -func buildAntigravityAuthURL(redirectURI, state string) string { - params := url.Values{} - params.Set("access_type", "offline") - params.Set("client_id", antigravityClientID) - params.Set("prompt", "consent") - params.Set("redirect_uri", redirectURI) - params.Set("response_type", "code") - params.Set("scope", strings.Join(antigravityScopes, " ")) - params.Set("state", state) - return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() -} - -func sanitizeAntigravityFileName(email string) string { - if strings.TrimSpace(email) == "" { - return "antigravity.json" - } - replacer := strings.NewReplacer("@", "_", ".", "_") - return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email)) -} - -// Antigravity API constants for project discovery -const ( - antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com" - antigravityAPIVersion = "v1internal" - antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1" - antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" - antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` -) - // FetchAntigravityProjectID exposes project discovery for external callers. func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) { - return fetchAntigravityProjectID(ctx, accessToken, httpClient) -} - -// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist. -// This uses the same approach as Gemini CLI to get the cloudaicompanionProject. -func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) { - // Call loadCodeAssist to get the project - loadReqBody := map[string]any{ - "metadata": map[string]string{ - "ideType": "ANTIGRAVITY", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", - }, - } - - rawBody, errMarshal := json.Marshal(loadReqBody) - if errMarshal != nil { - return "", fmt.Errorf("marshal request body: %w", errMarshal) - } - - endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) - if err != nil { - return "", fmt.Errorf("create request: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", antigravityAPIUserAgent) - req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) - req.Header.Set("Client-Metadata", antigravityClientMetadata) - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return "", fmt.Errorf("execute request: %w", errDo) - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose) - } - }() - - bodyBytes, errRead := io.ReadAll(resp.Body) - if errRead != nil { - return "", fmt.Errorf("read response: %w", errRead) - } - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) - } - - var loadResp map[string]any - if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil { - return "", fmt.Errorf("decode response: %w", errDecode) - } - - // Extract projectID from response - projectID := "" - if id, ok := loadResp["cloudaicompanionProject"].(string); ok { - projectID = strings.TrimSpace(id) - } - if projectID == "" { - if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok { - if id, okID := projectMap["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - } - } - - if projectID == "" { - tierID := "legacy-tier" - if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { - for _, rawTier := range tiers { - tier, okTier := rawTier.(map[string]any) - if !okTier { - continue - } - if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { - if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { - tierID = strings.TrimSpace(id) - break - } - } - } - } - - projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient) - if err != nil { - return "", err - } - return projectID, nil - } - - return projectID, nil -} - -// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion. -// It returns an empty string when the operation times out or completes without a project ID. -func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) { - if httpClient == nil { - httpClient = http.DefaultClient - } - fmt.Println("Antigravity: onboarding user...", tierID) - requestBody := map[string]any{ - "tierId": tierID, - "metadata": map[string]string{ - "ideType": "ANTIGRAVITY", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", - }, - } - - rawBody, errMarshal := json.Marshal(requestBody) - if errMarshal != nil { - return "", fmt.Errorf("marshal request body: %w", errMarshal) - } - - maxAttempts := 5 - for attempt := 1; attempt <= maxAttempts; attempt++ { - log.Debugf("Polling attempt %d/%d", attempt, maxAttempts) - - reqCtx := ctx - var cancel context.CancelFunc - if reqCtx == nil { - reqCtx = context.Background() - } - reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) - - endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion) - req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) - if errRequest != nil { - cancel() - return "", fmt.Errorf("create request: %w", errRequest) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", antigravityAPIUserAgent) - req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) - req.Header.Set("Client-Metadata", antigravityClientMetadata) - - resp, errDo := httpClient.Do(req) - if errDo != nil { - cancel() - return "", fmt.Errorf("execute request: %w", errDo) - } - - bodyBytes, errRead := io.ReadAll(resp.Body) - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("close body error: %v", errClose) - } - cancel() - - if errRead != nil { - return "", fmt.Errorf("read response: %w", errRead) - } - - if resp.StatusCode == http.StatusOK { - var data map[string]any - if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil { - return "", fmt.Errorf("decode response: %w", errDecode) - } - - if done, okDone := data["done"].(bool); okDone && done { - projectID := "" - if responseData, okResp := data["response"].(map[string]any); okResp { - switch projectValue := responseData["cloudaicompanionProject"].(type) { - case map[string]any: - if id, okID := projectValue["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - case string: - projectID = strings.TrimSpace(projectValue) - } - } - - if projectID != "" { - log.Infof("Successfully fetched project_id: %s", projectID) - return projectID, nil - } - - return "", fmt.Errorf("no project_id in response") - } - - time.Sleep(2 * time.Second) - continue - } - - responsePreview := strings.TrimSpace(string(bodyBytes)) - if len(responsePreview) > 500 { - responsePreview = responsePreview[:500] - } - - responseErr := responsePreview - if len(responseErr) > 200 { - responseErr = responseErr[:200] - } - return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) - } - - return "", nil + cfg := &config.Config{} + authSvc := antigravity.NewAntigravityAuth(cfg, httpClient) + return authSvc.FetchProjectID(ctx, accessToken) }