mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-11 00:03:36 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f646800f6 | ||
|
|
ca993238f3 | ||
|
|
d1220de02d | ||
|
|
cf9a246d53 | ||
|
|
13eb5268de | ||
|
|
88798816f2 | ||
|
|
598f0af19b | ||
|
|
a33f5d31fc | ||
|
|
54acd69e9d | ||
|
|
d687ee2777 | ||
|
|
54c2fefbad | ||
|
|
506699fba1 | ||
|
|
f7b17ee6ec | ||
|
|
408614c74c | ||
|
|
68a27772b3 | ||
|
|
de87fb622b | ||
|
|
0155a01bb1 | ||
|
|
cfeee5d511 | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
9b956f6338 | ||
|
|
92c62bb2fb |
@@ -28,4 +28,6 @@ bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.serena/*
|
||||
.bmad/*
|
||||
.agent/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
|
||||
23
.github/workflows/pr-test-build.yml
vendored
Normal file
23
.github/workflows/pr-test-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: pr-test-build
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -o test-output ./cmd/server
|
||||
rm -f test-output
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -32,7 +32,9 @@ GEMINI.md
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
.mcp/cache/
|
||||
|
||||
# macOS
|
||||
|
||||
@@ -78,6 +78,7 @@ func main() {
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
var kiroAWSAuthCode bool
|
||||
var kiroImport bool
|
||||
var githubCopilotLogin bool
|
||||
var projectID string
|
||||
@@ -101,6 +102,7 @@ func main() {
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
||||
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
@@ -513,6 +515,10 @@ func main() {
|
||||
// Users can explicitly override with --no-incognito
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
cmd.DoKiroAWSLogin(cfg, options)
|
||||
} else if kiroAWSAuthCode {
|
||||
// For Kiro auth with authorization code flow (better UX)
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
||||
} else if kiroImport {
|
||||
cmd.DoKiroImport(cfg, options)
|
||||
} else {
|
||||
|
||||
166
internal/auth/kiro/codewhisperer_client.go
Normal file
166
internal/auth/kiro/codewhisperer_client.go
Normal file
@@ -0,0 +1,166 @@
|
||||
// Package kiro provides a CodeWhisperer API client for fetching user info.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// codeWhispererAPI is the base URL that CodeWhisperer request paths are
// appended to.
const codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"

// kiroVersion is the version string embedded in the user-agent headers sent
// to the CodeWhisperer API.
const kiroVersion = "0.6.18"
|
||||
|
||||
// CodeWhispererClient issues requests against the CodeWhisperer API.
type CodeWhispererClient struct {
	// httpClient performs the outgoing HTTP requests.
	httpClient *http.Client

	// machineID is embedded in the user-agent headers of each request.
	machineID string
}
|
||||
|
||||
// Payload types decoded from the getUsageLimits API response. Pointer fields
// stay nil when the corresponding JSON key is absent.
type (
	// UsageLimitsResponse is the top-level getUsageLimits response body.
	UsageLimitsResponse struct {
		DaysUntilReset     *int              `json:"daysUntilReset,omitempty"`
		NextDateReset      *float64          `json:"nextDateReset,omitempty"`
		UserInfo           *UserInfo         `json:"userInfo,omitempty"`
		SubscriptionInfo   *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
		UsageBreakdownList []UsageBreakdown  `json:"usageBreakdownList,omitempty"`
	}

	// UserInfo carries the user identity fields returned by the API.
	UserInfo struct {
		Email  string `json:"email,omitempty"`
		UserID string `json:"userId,omitempty"`
	}

	// SubscriptionInfo carries the subscription title and type.
	SubscriptionInfo struct {
		SubscriptionTitle string `json:"subscriptionTitle,omitempty"`
		Type              string `json:"type,omitempty"`
	}

	// UsageBreakdown is one entry of the usage breakdown list.
	UsageBreakdown struct {
		UsageLimit                *int     `json:"usageLimit,omitempty"`
		CurrentUsage              *int     `json:"currentUsage,omitempty"`
		UsageLimitWithPrecision   *float64 `json:"usageLimitWithPrecision,omitempty"`
		CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
		NextDateReset             *float64 `json:"nextDateReset,omitempty"`
		DisplayName               string   `json:"displayName,omitempty"`
		ResourceType              string   `json:"resourceType,omitempty"`
	}
)
|
||||
|
||||
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
||||
func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhispererClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
if machineID == "" {
|
||||
machineID = uuid.New().String()
|
||||
}
|
||||
return &CodeWhispererClient{
|
||||
httpClient: client,
|
||||
machineID: machineID,
|
||||
}
|
||||
}
|
||||
|
||||
// generateInvocationID generates a unique invocation ID.
|
||||
func generateInvocationID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
||||
// This is the recommended way to get user email after login.
|
||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
|
||||
url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers to match Kiro IDE
|
||||
xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
|
||||
userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("x-amz-user-agent", xAmzUserAgent)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||
req.Header.Set("Connection", "close")
|
||||
|
||||
log.Debugf("codewhisperer: GET %s", url)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageLimitsResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
||||
// This is more reliable than JWT parsing as it uses the official API.
|
||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
|
||||
resp, err := c.GetUsageLimits(ctx, accessToken)
|
||||
if err != nil {
|
||||
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if resp.UserInfo != nil && resp.UserInfo.Email != "" {
|
||||
log.Debugf("codewhisperer: got email from API: %s", resp.UserInfo.Email)
|
||||
return resp.UserInfo.Email
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: no email in response")
|
||||
return ""
|
||||
}
|
||||
|
||||
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
||||
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
|
||||
// Method 1: Try CodeWhisperer API (most reliable)
|
||||
cwClient := NewCodeWhispererClient(cfg, "")
|
||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 2: Try SSO OIDC userinfo endpoint
|
||||
ssoClient := NewSSOOIDCClient(cfg)
|
||||
email = ssoClient.FetchUserEmail(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 3: Fallback to JWT parsing
|
||||
return ExtractEmailFromJWT(accessToken)
|
||||
}
|
||||
@@ -163,6 +163,13 @@ func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, err
|
||||
return ssoClient.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||
return ssoClient.LoginWithBuilderIDAuthCode(ctx)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens.
|
||||
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
|
||||
payload := map[string]string{
|
||||
|
||||
@@ -3,9 +3,14 @@ package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -25,6 +30,13 @@ const (
|
||||
|
||||
// Polling interval
|
||||
pollInterval = 5 * time.Second
|
||||
|
||||
// Authorization code flow callback
|
||||
authCodeCallbackPath = "/oauth/callback"
|
||||
authCodeCallbackPort = 19877
|
||||
|
||||
// User-Agent to match official Kiro IDE
|
||||
kiroUserAgent = "KiroIDE"
|
||||
)
|
||||
|
||||
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
||||
@@ -73,13 +85,11 @@ type CreateTokenResponse struct {
|
||||
|
||||
// RegisterClient registers a new OIDC client with AWS.
|
||||
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||
// Generate unique client name for each registration to support multiple accounts
|
||||
clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano())
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"clientName": clientName,
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"},
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
@@ -92,6 +102,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -135,6 +146,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID,
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -179,6 +191,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -240,6 +253,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -370,8 +384,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Extract email from JWT access token
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
@@ -399,6 +413,68 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||
// Falls back to JWT parsing if userinfo fails.
|
||||
func (c *SSOOIDCClient) FetchUserEmail(ctx context.Context, accessToken string) string {
|
||||
// Method 1: Try userinfo endpoint (standard OIDC)
|
||||
email := c.tryUserInfoEndpoint(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 2: Fallback to JWT parsing
|
||||
return ExtractEmailFromJWT(accessToken)
|
||||
}
|
||||
|
||||
// tryUserInfoEndpoint attempts to get user info from AWS SSO OIDC userinfo endpoint.
|
||||
func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken string) string {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ssoOIDCEndpoint+"/userinfo", nil)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("userinfo request failed: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
log.Debugf("userinfo endpoint returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
return ""
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Debugf("userinfo response: %s", string(respBody))
|
||||
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
Sub string `json:"sub"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &userInfo); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if userInfo.Email != "" {
|
||||
return userInfo.Email
|
||||
}
|
||||
if userInfo.PreferredUsername != "" && strings.Contains(userInfo.PreferredUsername, "@") {
|
||||
return userInfo.PreferredUsername
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
|
||||
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
|
||||
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
|
||||
@@ -525,3 +601,323 @@ func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken s
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
|
||||
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
|
||||
payload := map[string]interface{}{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||
"redirectUris": []string{redirectURI},
|
||||
"issuerUrl": builderIDStartURL,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("register client for auth code failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterClientResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// AuthCodeCallbackResult is what the local callback server reports after the
// browser is redirected back from the authorization endpoint.
type AuthCodeCallbackResult struct {
	// Code is the value of the "code" query parameter.
	Code string

	// State is the value of the "state" query parameter.
	State string

	// Error is non-empty when the callback failed.
	Error string
}
|
||||
|
||||
// startAuthCodeCallbackServer starts a local HTTP server to receive the authorization code callback.
|
||||
func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthCodeCallbackResult, error) {
|
||||
// Try to find an available port
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", authCodeCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port
|
||||
log.Warnf("sso oidc: default port %d is busy, falling back to dynamic port", authCodeCallbackPort)
|
||||
listener, err = net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
|
||||
resultChan := make(chan AuthCodeCallbackResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(authCodeCallbackPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
// Send response to browser
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if errParam != "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- AuthCodeCallbackResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Successful</title></head>
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("auth code callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(10 * time.Minute):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// generatePKCEForAuthCode produces a PKCE pair for the authorization code
// flow.
//
// The verifier is 32 random bytes in unpadded base64url form; the challenge is
// the unpadded base64url form of the SHA-256 digest of the verifier string.
// An error is returned only when random bytes cannot be read.
func generatePKCEForAuthCode() (verifier, challenge string, err error) {
	var seed [32]byte
	if _, readErr := rand.Read(seed[:]); readErr != nil {
		return "", "", fmt.Errorf("failed to generate random bytes: %w", readErr)
	}

	verifier = base64.RawURLEncoding.EncodeToString(seed[:])
	digest := sha256.Sum256([]byte(verifier))
	challenge = base64.RawURLEncoding.EncodeToString(digest[:])
	return verifier, challenge, nil
}
|
||||
|
||||
// generateStateForAuthCode returns a random OAuth state parameter: 16 random
// bytes in unpadded base64url form.
//
// An error is returned only when random bytes cannot be read; it is wrapped
// with the same context message generatePKCEForAuthCode uses.
func generateStateForAuthCode() (string, error) {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}
|
||||
|
||||
// CreateTokenWithAuthCode exchanges authorization code for tokens.
|
||||
func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI string) (*CreateTokenResponse, error) {
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"code": code,
|
||||
"codeVerifier": codeVerifier,
|
||||
"redirectUri": redirectURI,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// LoginWithBuilderIDAuthCode performs the authorization code flow for AWS Builder ID.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS Builder ID - Auth Code) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Generate PKCE and state
|
||||
codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateStateForAuthCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Start callback server
|
||||
fmt.Println("\nStarting callback server...")
|
||||
redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("Callback server started, redirect URI: %s", redirectURI)
|
||||
|
||||
// Step 3: Register client with auth code grant type
|
||||
fmt.Println("Registering client...")
|
||||
regResp, err := c.RegisterClientForAuthCode(ctx, redirectURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||
|
||||
// Step 4: Build authorization URL
|
||||
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations"
|
||||
authURL := fmt.Sprintf("%s/authorize?response_type=code&client_id=%s&redirect_uri=%s&scopes=%s&state=%s&code_challenge=%s&code_challenge_method=S256",
|
||||
ssoOIDCEndpoint,
|
||||
regResp.ClientID,
|
||||
redirectURI,
|
||||
scopes,
|
||||
state,
|
||||
codeChallenge,
|
||||
)
|
||||
|
||||
// Step 5: Open browser
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Println(" Opening browser for authentication...")
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf("\n URL: %s\n\n", authURL)
|
||||
|
||||
// Set incognito mode
|
||||
if c.cfg != nil {
|
||||
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||
} else {
|
||||
browser.SetIncognitoMode(true)
|
||||
}
|
||||
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Could not open browser automatically: %v", err)
|
||||
fmt.Println(" ⚠ Could not open browser automatically.")
|
||||
fmt.Println(" Please open the URL above in your browser manually.")
|
||||
} else {
|
||||
fmt.Println(" (Browser opened automatically)")
|
||||
}
|
||||
|
||||
fmt.Println("\n Waiting for authorization callback...")
|
||||
|
||||
// Step 6: Wait for callback
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
browser.CloseBrowser()
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(10 * time.Minute):
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
case result := <-resultChan:
|
||||
if result.Error != "" {
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("authorization failed: %s", result.Error)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Close browser
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Step 7: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
tokenResp, err := c.CreateTokenWithAuthCode(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Step 8: Get profile ARN
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
Provider: "AWS",
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,6 +116,54 @@ func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
|
||||
fmt.Println("Kiro AWS authentication successful!")
|
||||
}
|
||||
|
||||
// DoKiroAWSAuthCodeLogin triggers Kiro authentication with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including prompts
|
||||
func DoKiroAWSAuthCodeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
// Note: Kiro defaults to incognito mode for multi-account support.
|
||||
// Users can override with --no-incognito if they want to use existing browser sessions.
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
// Use KiroAuthenticator with AWS Builder ID login (authorization code flow)
|
||||
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||
record, err := authenticator.LoginWithAuthCode(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("Kiro AWS authentication (auth code) failed: %v", err)
|
||||
fmt.Println("\nTroubleshooting:")
|
||||
fmt.Println("1. Make sure you have an AWS Builder ID")
|
||||
fmt.Println("2. Complete the authorization in the browser")
|
||||
fmt.Println("3. If callback fails, try: --kiro-aws-login (device code flow)")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the auth record
|
||||
savedPath, err := manager.SaveAuth(record, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to save auth: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kiro AWS authentication successful!")
|
||||
}
|
||||
|
||||
// DoKiroImport imports Kiro token from Kiro IDE's token file.
|
||||
// This is useful for users who have already logged in via Kiro IDE
|
||||
// and want to use the same credentials in CLI Proxy API.
|
||||
|
||||
@@ -160,7 +160,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -175,7 +175,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -240,7 +240,22 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -255,7 +270,7 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -317,11 +332,26 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
Name: "models/gemini-3-pro-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Preview",
|
||||
Description: "Gemini 3 Pro Preview",
|
||||
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -387,7 +417,22 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-pro-latest",
|
||||
@@ -698,8 +743,9 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
@@ -787,6 +833,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.2",
|
||||
Description: "OpenAI GPT-5.2 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
Object: "model",
|
||||
|
||||
@@ -323,8 +323,9 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
|
||||
@@ -33,15 +33,16 @@ import (
|
||||
const (
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -93,6 +94,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
@@ -186,6 +188,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
@@ -518,6 +521,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
@@ -650,9 +654,131 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request (not supported for Antigravity).
|
||||
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
|
||||
// CountTokens counts tokens for the given request using the Antigravity API.
|
||||
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return cliproxyexecutor.Response{}, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
base := strings.TrimSuffix(baseURL, "/")
|
||||
if base == "" {
|
||||
base = buildBaseURL(auth)
|
||||
}
|
||||
|
||||
var requestURL strings.Builder
|
||||
requestURL.WriteString(base)
|
||||
requestURL.WriteString(antigravityCountTokensPath)
|
||||
if opts.Alt != "" {
|
||||
requestURL.WriteString("?$alt=")
|
||||
requestURL.WriteString(url.QueryEscape(opts.Alt))
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return cliproxyexecutor.Response{}, errReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
default:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
||||
@@ -1122,6 +1248,8 @@ func modelName2Alias(modelName string) string {
|
||||
return "gemini-3-pro-image-preview"
|
||||
case "gemini-3-pro-high":
|
||||
return "gemini-3-pro-preview"
|
||||
case "gemini-3-flash":
|
||||
return "gemini-3-flash-preview"
|
||||
case "claude-sonnet-4-5":
|
||||
return "gemini-claude-sonnet-4-5"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
@@ -1143,6 +1271,8 @@ func alias2ModelName(modelName string) string {
|
||||
return "gemini-3-pro-image"
|
||||
case "gemini-3-pro-preview":
|
||||
return "gemini-3-pro-high"
|
||||
case "gemini-3-flash-preview":
|
||||
return "gemini-3-flash"
|
||||
case "gemini-claude-sonnet-4-5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "gemini-claude-sonnet-4-5-thinking":
|
||||
|
||||
@@ -79,6 +79,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -217,6 +218,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -418,6 +420,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,10 +73,12 @@ func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||
switch {
|
||||
case sanitized == "":
|
||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
case strings.HasPrefix(sanitized, "gpt-5.2"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
@@ -154,10 +156,10 @@ func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
|
||||
|
||||
// Collect system prompt (can be string or array of content blocks)
|
||||
collectClaudeSystem(root.Get("system"), &segments)
|
||||
|
||||
|
||||
// Collect messages
|
||||
collectClaudeMessages(root.Get("messages"), &segments)
|
||||
|
||||
|
||||
// Collect tools
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
|
||||
@@ -222,20 +222,19 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
||||
kiroTools := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API doesn't accept max_tokens for thinking. Instead, thinking mode is enabled
|
||||
// by injecting <thinking_mode> and <max_thinking_length> tags into the system prompt.
|
||||
// We use a fixed max_thinking_length value since Kiro handles the actual budget internally.
|
||||
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
||||
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
||||
// rather than inline <thinking> tags in assistantResponseEvent.
|
||||
// We use a high max_thinking_length to allow extensive reasoning.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>interleaved</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>
|
||||
|
||||
IMPORTANT: You MUST use <thinking>...</thinking> tags to show your reasoning process before providing your final response. Think step by step inside the thinking tags.`
|
||||
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = thinkingHint
|
||||
}
|
||||
log.Infof("kiro: injected thinking prompt, has_tools: %v", len(kiroTools) > 0)
|
||||
log.Infof("kiro: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
|
||||
}
|
||||
|
||||
// Process messages and build history
|
||||
|
||||
@@ -231,20 +231,19 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
|
||||
kiroTools := convertOpenAIToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API doesn't accept max_tokens for thinking. Instead, thinking mode is enabled
|
||||
// by injecting <thinking_mode> and <max_thinking_length> tags into the system prompt.
|
||||
// We use a fixed max_thinking_length value since Kiro handles the actual budget internally.
|
||||
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
||||
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
||||
// rather than inline <thinking> tags in assistantResponseEvent.
|
||||
// We use a high max_thinking_length to allow extensive reasoning.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>interleaved</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>
|
||||
|
||||
IMPORTANT: You MUST use <thinking>...</thinking> tags to show your reasoning process before providing your final response. Think step by step inside the thinking tags.`
|
||||
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = thinkingHint
|
||||
}
|
||||
log.Debugf("kiro-openai: injected thinking prompt")
|
||||
log.Debugf("kiro-openai: injected thinking prompt (official mode)")
|
||||
}
|
||||
|
||||
// Process messages and build history
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -13,6 +14,44 @@ const (
|
||||
GeminiOriginalModelMetadataKey = "gemini_original_model"
|
||||
)
|
||||
|
||||
// Regular expressions used to classify a model name by Gemini family.
// All of them are case-insensitive, anchored at the start of the name, and
// accept "-", "_" or nothing between "gemini" and the version number.
var (
	// gemini3Pattern: "gemini", version 3, then a "-" or "_" separator.
	gemini3Pattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]`)
	// gemini3ProPattern: the Gemini 3 prefix followed by "pro".
	gemini3ProPattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]pro`)
	// gemini3FlashPattern: the Gemini 3 prefix followed by "flash".
	gemini3FlashPattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]flash`)
	// gemini25Pattern: "gemini", version 2.5, then a "-" or "_" separator.
	gemini25Pattern = regexp.MustCompile(`(?i)^gemini[_-]?2\.5[_-]`)
)

// IsGemini3Model reports whether model belongs to the Gemini 3 family.
// Gemini 3 models should use thinkingLevel (string) instead of thinkingBudget (number).
func IsGemini3Model(model string) bool {
	matched := gemini3Pattern.MatchString(model)
	return matched
}

// IsGemini3ProModel reports whether model is a Gemini 3 Pro variant.
// Gemini 3 Pro supports thinkingLevel: "low", "high" (default: "high").
func IsGemini3ProModel(model string) bool {
	matched := gemini3ProPattern.MatchString(model)
	return matched
}

// IsGemini3FlashModel reports whether model is a Gemini 3 Flash variant.
// Gemini 3 Flash supports thinkingLevel: "minimal", "low", "medium", "high" (default: "high").
func IsGemini3FlashModel(model string) bool {
	matched := gemini3FlashPattern.MatchString(model)
	return matched
}

// IsGemini25Model reports whether model belongs to the Gemini 2.5 family.
// Gemini 2.5 models should use thinkingBudget (number).
func IsGemini25Model(model string) bool {
	matched := gemini25Pattern.MatchString(model)
	return matched
}
|
||||
|
||||
// Valid thinkingLevel values per Gemini 3 variant.
var (
	// Gemini3ProThinkingLevels lists the thinkingLevel values accepted for Gemini 3 Pro models.
	Gemini3ProThinkingLevels = []string{"low", "high"}

	// Gemini3FlashThinkingLevels lists the thinkingLevel values accepted for Gemini 3 Flash models.
	Gemini3FlashThinkingLevels = []string{"minimal", "low", "medium", "high"}
)
|
||||
|
||||
func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte {
|
||||
if budget == nil && includeThoughts == nil {
|
||||
return body
|
||||
@@ -69,10 +108,141 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGeminiThinkingLevel applies thinkingLevel config for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Per Google's documentation, Gemini 3 models should use thinkingLevel instead of thinkingBudget.
|
||||
func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool) []byte {
|
||||
if level == "" && includeThoughts == nil {
|
||||
return body
|
||||
}
|
||||
updated := body
|
||||
if level != "" {
|
||||
valuePath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, level)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
// Default to including thoughts when a level is set but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && level != "" {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGeminiCLIThinkingLevel applies thinkingLevel config for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Per Google's documentation, Gemini 3 models should use thinkingLevel instead of thinkingBudget.
|
||||
func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *bool) []byte {
|
||||
if level == "" && includeThoughts == nil {
|
||||
return body
|
||||
}
|
||||
updated := body
|
||||
if level != "" {
|
||||
valuePath := "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, level)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
// Default to including thoughts when a level is set but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && level != "" {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ValidateGemini3ThinkingLevel validates that the thinkingLevel is valid for the Gemini 3 model variant.
|
||||
// Returns the validated level (normalized to lowercase) and true if valid, or empty string and false if invalid.
|
||||
func ValidateGemini3ThinkingLevel(model, level string) (string, bool) {
|
||||
if level == "" {
|
||||
return "", false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(level))
|
||||
|
||||
var validLevels []string
|
||||
if IsGemini3ProModel(model) {
|
||||
validLevels = Gemini3ProThinkingLevels
|
||||
} else if IsGemini3FlashModel(model) {
|
||||
validLevels = Gemini3FlashThinkingLevels
|
||||
} else if IsGemini3Model(model) {
|
||||
// Unknown Gemini 3 variant - allow all levels as fallback
|
||||
validLevels = Gemini3FlashThinkingLevels
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, valid := range validLevels {
|
||||
if normalized == valid {
|
||||
return normalized, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ThinkingBudgetToGemini3Level converts a thinkingBudget to a thinkingLevel for Gemini 3 models.
|
||||
// This provides backward compatibility when thinkingBudget is provided for Gemini 3 models.
|
||||
// Returns the appropriate thinkingLevel and true if conversion is possible.
|
||||
func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
|
||||
if !IsGemini3Model(model) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Map budget to level based on Google's documentation
|
||||
// Gemini 3 Pro: "low", "high" (default: "high")
|
||||
// Gemini 3 Flash: "minimal", "low", "medium", "high" (default: "high")
|
||||
switch {
|
||||
case budget == -1:
|
||||
// Dynamic budget maps to "high" (API default)
|
||||
return "high", true
|
||||
case budget == 0:
|
||||
// Zero budget - Gemini 3 doesn't support disabling thinking
|
||||
// Map to lowest available level
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "minimal", true
|
||||
}
|
||||
return "low", true
|
||||
case budget > 0 && budget <= 512:
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "minimal", true
|
||||
}
|
||||
return "low", true
|
||||
case budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "medium", true
|
||||
}
|
||||
return "low", true // Pro doesn't have medium, use low
|
||||
default:
|
||||
return "high", true
|
||||
}
|
||||
}
|
||||
|
||||
// modelsWithDefaultThinking lists models that should have thinking enabled by default
// when no explicit thinkingConfig is provided.
//
// Each model name appears exactly once: a duplicate constant key in a map
// literal is rejected by the Go compiler.
var modelsWithDefaultThinking = map[string]bool{
	"gemini-3-pro-preview":       true,
	"gemini-3-pro-image-preview": true,
	"gemini-3-flash-preview":     true,
}
|
||||
|
||||
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
|
||||
@@ -83,6 +253,7 @@ func ModelHasDefaultThinking(model string) bool {
|
||||
// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
@@ -90,14 +261,59 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return updated
|
||||
}
|
||||
// Gemini 2.5 and other models use thinkingBudget
|
||||
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
@@ -105,6 +321,14 @@ func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return updated
|
||||
}
|
||||
// Gemini 2.5 and other models use thinkingBudget
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
return updated
|
||||
@@ -128,12 +352,29 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
|
||||
// request body (generationConfig.thinkingConfig.thinkingBudget path).
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation.
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
budget := gjson.GetBytes(body, budgetPath)
|
||||
if !budget.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
if IsGemini3Model(model) {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
return updated
|
||||
}
|
||||
// If conversion fails, just remove the budget (let API use default)
|
||||
updated, _ := sjson.DeleteBytes(body, budgetPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// For Gemini 2.5 and other models, normalize the budget value
|
||||
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||
return updated
|
||||
@@ -141,12 +382,29 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
|
||||
// request body (request.generationConfig.thinkingConfig.thinkingBudget path).
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation.
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
budget := gjson.GetBytes(body, budgetPath)
|
||||
if !budget.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
if IsGemini3Model(model) {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
return updated
|
||||
}
|
||||
// If conversion fails, just remove the budget (let API use default)
|
||||
updated, _ := sjson.DeleteBytes(body, budgetPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// For Gemini 2.5 and other models, normalize the budget value
|
||||
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||
return updated
|
||||
@@ -218,34 +476,42 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget".
|
||||
// "high" -> 32768
|
||||
// "low" -> 128
|
||||
// It removes "thinkingLevel" after conversion.
|
||||
func ConvertThinkingLevelToBudget(body []byte) []byte {
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel as-is (does not convert).
|
||||
// Mappings for Gemini 2.5:
|
||||
// - "high" -> 32768
|
||||
// - "medium" -> 8192
|
||||
// - "low" -> 1024
|
||||
// - "minimal" -> 512
|
||||
//
|
||||
// It removes "thinkingLevel" after conversion (for Gemini 2.5 only).
|
||||
func ConvertThinkingLevelToBudget(body []byte, model string) []byte {
|
||||
levelPath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, preserve thinkingLevel - don't convert to budget
|
||||
if IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
level := strings.ToLower(res.String())
|
||||
var budget int
|
||||
switch level {
|
||||
case "high":
|
||||
budget = 32768
|
||||
case "medium":
|
||||
budget = 8192
|
||||
case "low":
|
||||
budget = 128
|
||||
budget = 1024
|
||||
case "minimal":
|
||||
budget = 512
|
||||
default:
|
||||
// If unknown level, we might just leave it or default.
|
||||
// User only specified high and low. We'll assume we shouldn't touch it if it's something else,
|
||||
// or maybe we should just remove the invalid level?
|
||||
// For safety adhering to strict instructions: "If high... if low...".
|
||||
// If it's something else, the upstream might fail anyway if we leave it,
|
||||
// but let's just delete the level if we processed it.
|
||||
// Actually, let's check if we need to do anything for other values.
|
||||
// For now, only handle high/low.
|
||||
return body
|
||||
// Unknown level - remove it and let the API use defaults
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// Set budget
|
||||
@@ -262,3 +528,50 @@ func ConvertThinkingLevelToBudget(body []byte) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudgetCLI checks for "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel as-is (does not convert).
|
||||
func ConvertThinkingLevelToBudgetCLI(body []byte, model string) []byte {
|
||||
levelPath := "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, preserve thinkingLevel - don't convert to budget
|
||||
if IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
level := strings.ToLower(res.String())
|
||||
var budget int
|
||||
switch level {
|
||||
case "high":
|
||||
budget = 32768
|
||||
case "medium":
|
||||
budget = 8192
|
||||
case "low":
|
||||
budget = 1024
|
||||
case "minimal":
|
||||
budget = 512
|
||||
default:
|
||||
// Unknown level - remove it and let the API use defaults
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// Set budget
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Remove level
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
@@ -117,6 +117,71 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use AWS Builder ID authorization code flow
|
||||
tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-aws",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "aws-builder-id-authcode",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
423
test/gemini3_thinking_level_test.go
Normal file
423
test/gemini3_thinking_level_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// registerGemini3Models loads Gemini 3 models into the registry for testing.
|
||||
func registerGemini3Models(t *testing.T) func() {
|
||||
t.Helper()
|
||||
reg := registry.GetGlobalRegistry()
|
||||
uid := fmt.Sprintf("gemini3-test-%d", time.Now().UnixNano())
|
||||
reg.RegisterClient(uid+"-gemini", "gemini", registry.GetGeminiModels())
|
||||
reg.RegisterClient(uid+"-aistudio", "aistudio", registry.GetAIStudioModels())
|
||||
return func() {
|
||||
reg.UnregisterClient(uid + "-gemini")
|
||||
reg.UnregisterClient(uid + "-aistudio")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemini3Model(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{"gemini-3-pro-preview", true},
|
||||
{"gemini-3-flash-preview", true},
|
||||
{"gemini_3_pro_preview", true},
|
||||
{"gemini-3-pro", true},
|
||||
{"gemini-3-flash", true},
|
||||
{"GEMINI-3-PRO-PREVIEW", true},
|
||||
{"gemini-2.5-pro", false},
|
||||
{"gemini-2.5-flash", false},
|
||||
{"gpt-5", false},
|
||||
{"claude-sonnet-4-5", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.model, func(t *testing.T) {
|
||||
got := util.IsGemini3Model(cs.model)
|
||||
if got != cs.expected {
|
||||
t.Fatalf("IsGemini3Model(%q) = %v, want %v", cs.model, got, cs.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemini3ProModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{"gemini-3-pro-preview", true},
|
||||
{"gemini_3_pro_preview", true},
|
||||
{"gemini-3-pro", true},
|
||||
{"GEMINI-3-PRO-PREVIEW", true},
|
||||
{"gemini-3-flash-preview", false},
|
||||
{"gemini-3-flash", false},
|
||||
{"gemini-2.5-pro", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.model, func(t *testing.T) {
|
||||
got := util.IsGemini3ProModel(cs.model)
|
||||
if got != cs.expected {
|
||||
t.Fatalf("IsGemini3ProModel(%q) = %v, want %v", cs.model, got, cs.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemini3FlashModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{"gemini-3-flash-preview", true},
|
||||
{"gemini_3_flash_preview", true},
|
||||
{"gemini-3-flash", true},
|
||||
{"GEMINI-3-FLASH-PREVIEW", true},
|
||||
{"gemini-3-pro-preview", false},
|
||||
{"gemini-3-pro", false},
|
||||
{"gemini-2.5-flash", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.model, func(t *testing.T) {
|
||||
got := util.IsGemini3FlashModel(cs.model)
|
||||
if got != cs.expected {
|
||||
t.Fatalf("IsGemini3FlashModel(%q) = %v, want %v", cs.model, got, cs.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGemini3ThinkingLevel(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
level string
|
||||
wantOK bool
|
||||
wantVal string
|
||||
}{
|
||||
// Gemini 3 Pro: supports "low", "high"
|
||||
{"pro-low", "gemini-3-pro-preview", "low", true, "low"},
|
||||
{"pro-high", "gemini-3-pro-preview", "high", true, "high"},
|
||||
{"pro-minimal-invalid", "gemini-3-pro-preview", "minimal", false, ""},
|
||||
{"pro-medium-invalid", "gemini-3-pro-preview", "medium", false, ""},
|
||||
|
||||
// Gemini 3 Flash: supports "minimal", "low", "medium", "high"
|
||||
{"flash-minimal", "gemini-3-flash-preview", "minimal", true, "minimal"},
|
||||
{"flash-low", "gemini-3-flash-preview", "low", true, "low"},
|
||||
{"flash-medium", "gemini-3-flash-preview", "medium", true, "medium"},
|
||||
{"flash-high", "gemini-3-flash-preview", "high", true, "high"},
|
||||
|
||||
// Case insensitivity
|
||||
{"flash-LOW-case", "gemini-3-flash-preview", "LOW", true, "low"},
|
||||
{"flash-High-case", "gemini-3-flash-preview", "High", true, "high"},
|
||||
{"pro-HIGH-case", "gemini-3-pro-preview", "HIGH", true, "high"},
|
||||
|
||||
// Invalid levels
|
||||
{"flash-invalid", "gemini-3-flash-preview", "xhigh", false, ""},
|
||||
{"flash-invalid-auto", "gemini-3-flash-preview", "auto", false, ""},
|
||||
{"flash-empty", "gemini-3-flash-preview", "", false, ""},
|
||||
|
||||
// Non-Gemini 3 models
|
||||
{"non-gemini3", "gemini-2.5-pro", "high", false, ""},
|
||||
{"gpt5", "gpt-5", "high", false, ""},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ValidateGemini3ThinkingLevel(cs.model, cs.level)
|
||||
if ok != cs.wantOK {
|
||||
t.Fatalf("ValidateGemini3ThinkingLevel(%q, %q) ok = %v, want %v", cs.model, cs.level, ok, cs.wantOK)
|
||||
}
|
||||
if got != cs.wantVal {
|
||||
t.Fatalf("ValidateGemini3ThinkingLevel(%q, %q) = %q, want %q", cs.model, cs.level, got, cs.wantVal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingBudgetToGemini3Level(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
budget int
|
||||
wantOK bool
|
||||
wantVal string
|
||||
}{
|
||||
// Gemini 3 Pro: maps to "low" or "high"
|
||||
{"pro-dynamic", "gemini-3-pro-preview", -1, true, "high"},
|
||||
{"pro-zero", "gemini-3-pro-preview", 0, true, "low"},
|
||||
{"pro-small", "gemini-3-pro-preview", 1000, true, "low"},
|
||||
{"pro-medium", "gemini-3-pro-preview", 8000, true, "low"},
|
||||
{"pro-large", "gemini-3-pro-preview", 20000, true, "high"},
|
||||
{"pro-huge", "gemini-3-pro-preview", 50000, true, "high"},
|
||||
|
||||
// Gemini 3 Flash: maps to "minimal", "low", "medium", "high"
|
||||
{"flash-dynamic", "gemini-3-flash-preview", -1, true, "high"},
|
||||
{"flash-zero", "gemini-3-flash-preview", 0, true, "minimal"},
|
||||
{"flash-tiny", "gemini-3-flash-preview", 500, true, "minimal"},
|
||||
{"flash-small", "gemini-3-flash-preview", 1000, true, "low"},
|
||||
{"flash-medium-val", "gemini-3-flash-preview", 8000, true, "medium"},
|
||||
{"flash-large", "gemini-3-flash-preview", 20000, true, "high"},
|
||||
{"flash-huge", "gemini-3-flash-preview", 50000, true, "high"},
|
||||
|
||||
// Non-Gemini 3 models should return false
|
||||
{"gemini25-budget", "gemini-2.5-pro", 8000, false, ""},
|
||||
{"gpt5-budget", "gpt-5", 8000, false, ""},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ThinkingBudgetToGemini3Level(cs.model, cs.budget)
|
||||
if ok != cs.wantOK {
|
||||
t.Fatalf("ThinkingBudgetToGemini3Level(%q, %d) ok = %v, want %v", cs.model, cs.budget, ok, cs.wantOK)
|
||||
}
|
||||
if got != cs.wantVal {
|
||||
t.Fatalf("ThinkingBudgetToGemini3Level(%q, %d) = %q, want %q", cs.model, cs.budget, got, cs.wantVal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGemini3ThinkingLevelFromMetadata(t *testing.T) {
|
||||
cleanup := registerGemini3Models(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
metadata map[string]any
|
||||
inputBody string
|
||||
wantLevel string
|
||||
wantInclude bool
|
||||
wantNoChange bool
|
||||
}{
|
||||
{
|
||||
name: "flash-minimal-from-suffix",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "minimal"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantLevel: "minimal",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "flash-medium-from-suffix",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "medium"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantLevel: "medium",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "pro-high-from-suffix",
|
||||
model: "gemini-3-pro-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "high"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantLevel: "high",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "no-metadata-no-change",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: nil,
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
{
|
||||
name: "non-gemini3-no-change",
|
||||
model: "gemini-2.5-pro",
|
||||
metadata: map[string]any{"reasoning_effort": "high"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
{
|
||||
name: "invalid-level-no-change",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "xhigh"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
input := []byte(cs.inputBody)
|
||||
result := util.ApplyGemini3ThinkingLevelFromMetadata(cs.model, cs.metadata, input)
|
||||
|
||||
if cs.wantNoChange {
|
||||
if string(result) != cs.inputBody {
|
||||
t.Fatalf("expected no change, but got: %s", string(result))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
level := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
if !level.Exists() {
|
||||
t.Fatalf("thinkingLevel not set in result: %s", string(result))
|
||||
}
|
||||
if level.String() != cs.wantLevel {
|
||||
t.Fatalf("thinkingLevel = %q, want %q", level.String(), cs.wantLevel)
|
||||
}
|
||||
|
||||
include := gjson.GetBytes(result, "generationConfig.thinkingConfig.includeThoughts")
|
||||
if cs.wantInclude && (!include.Exists() || !include.Bool()) {
|
||||
t.Fatalf("includeThoughts should be true, got: %s", string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGemini3ThinkingLevelFromMetadataCLI(t *testing.T) {
|
||||
cleanup := registerGemini3Models(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
metadata map[string]any
|
||||
inputBody string
|
||||
wantLevel string
|
||||
wantInclude bool
|
||||
wantNoChange bool
|
||||
}{
|
||||
{
|
||||
name: "flash-minimal-from-suffix-cli",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "minimal"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantLevel: "minimal",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "flash-low-from-suffix-cli",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "low"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantLevel: "low",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "pro-low-from-suffix-cli",
|
||||
model: "gemini-3-pro-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "low"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantLevel: "low",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "no-metadata-no-change-cli",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: nil,
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
{
|
||||
name: "non-gemini3-no-change-cli",
|
||||
model: "gemini-2.5-pro",
|
||||
metadata: map[string]any{"reasoning_effort": "high"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
input := []byte(cs.inputBody)
|
||||
result := util.ApplyGemini3ThinkingLevelFromMetadataCLI(cs.model, cs.metadata, input)
|
||||
|
||||
if cs.wantNoChange {
|
||||
if string(result) != cs.inputBody {
|
||||
t.Fatalf("expected no change, but got: %s", string(result))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
level := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
if !level.Exists() {
|
||||
t.Fatalf("thinkingLevel not set in result: %s", string(result))
|
||||
}
|
||||
if level.String() != cs.wantLevel {
|
||||
t.Fatalf("thinkingLevel = %q, want %q", level.String(), cs.wantLevel)
|
||||
}
|
||||
|
||||
include := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts")
|
||||
if cs.wantInclude && (!include.Exists() || !include.Bool()) {
|
||||
t.Fatalf("includeThoughts should be true, got: %s", string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGeminiThinkingBudget_Gemini3Conversion(t *testing.T) {
|
||||
cleanup := registerGemini3Models(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
inputBody string
|
||||
wantLevel string
|
||||
wantBudget bool // if true, expect thinkingBudget instead of thinkingLevel
|
||||
}{
|
||||
{
|
||||
name: "gemini3-flash-budget-to-level",
|
||||
model: "gemini-3-flash-preview",
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":8000}}}`,
|
||||
wantLevel: "medium",
|
||||
},
|
||||
{
|
||||
name: "gemini3-pro-budget-to-level",
|
||||
model: "gemini-3-pro-preview",
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":20000}}}`,
|
||||
wantLevel: "high",
|
||||
},
|
||||
{
|
||||
name: "gemini25-keeps-budget",
|
||||
model: "gemini-2.5-pro",
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":8000}}}`,
|
||||
wantBudget: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
result := util.NormalizeGeminiThinkingBudget(cs.model, []byte(cs.inputBody))
|
||||
|
||||
if cs.wantBudget {
|
||||
budget := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
if !budget.Exists() {
|
||||
t.Fatalf("thinkingBudget should exist for non-Gemini3 model: %s", string(result))
|
||||
}
|
||||
level := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
if level.Exists() {
|
||||
t.Fatalf("thinkingLevel should not exist for non-Gemini3 model: %s", string(result))
|
||||
}
|
||||
} else {
|
||||
level := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
if !level.Exists() {
|
||||
t.Fatalf("thinkingLevel should exist for Gemini3 model: %s", string(result))
|
||||
}
|
||||
if level.String() != cs.wantLevel {
|
||||
t.Fatalf("thinkingLevel = %q, want %q", level.String(), cs.wantLevel)
|
||||
}
|
||||
budget := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
if budget.Exists() {
|
||||
t.Fatalf("thinkingBudget should be removed for Gemini3 model: %s", string(result))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user