From 3a9ac7ef331da59da78d8504cab00a639f837cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Thu, 27 Nov 2025 20:14:30 +0100 Subject: [PATCH 01/16] feat(auth): add GitHub Copilot authentication and API integration Add complete GitHub Copilot support including: - Device flow OAuth authentication via GitHub's official client ID - Token management with automatic caching (25 min TTL) - OpenAI-compatible API executor for api.githubcopilot.com - 16 model definitions (GPT-5 variants, Claude variants, Gemini, Grok, Raptor) - CLI login command via -github-copilot-login flag - SDK authenticator and refresh registry integration Enables users to authenticate with their GitHub Copilot subscription and use it as a backend provider alongside existing providers. --- cmd/server/main.go | 5 + internal/auth/copilot/copilot_auth.go | 301 +++++++++++++++ internal/auth/copilot/errors.go | 183 +++++++++ internal/auth/copilot/oauth.go | 259 +++++++++++++ internal/auth/copilot/token.go | 93 +++++ internal/cmd/auth_manager.go | 3 +- internal/cmd/github_copilot_login.go | 44 +++ internal/registry/model_definitions.go | 184 +++++++++ .../executor/github_copilot_executor.go | 354 ++++++++++++++++++ sdk/auth/github_copilot.go | 133 +++++++ sdk/auth/refresh_registry.go | 1 + sdk/cliproxy/service.go | 4 + 12 files changed, 1563 insertions(+), 1 deletion(-) create mode 100644 internal/auth/copilot/copilot_auth.go create mode 100644 internal/auth/copilot/errors.go create mode 100644 internal/auth/copilot/oauth.go create mode 100644 internal/auth/copilot/token.go create mode 100644 internal/cmd/github_copilot_login.go create mode 100644 internal/runtime/executor/github_copilot_executor.go create mode 100644 sdk/auth/github_copilot.go diff --git a/cmd/server/main.go b/cmd/server/main.go index bbf500e7..b8ee66ba 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -62,6 +62,7 @@ func main() { var iflowCookie bool var noBrowser bool var antigravityLogin bool + var githubCopilotLogin bool var projectID string var vertexImport string var configPath string @@ -76,6 +77,7 @@ func main() { flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") + flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -436,6 +438,9 @@ func main() { } else if antigravityLogin { // Handle Antigravity login cmd.DoAntigravityLogin(cfg, options) + } else if githubCopilotLogin { + // Handle GitHub Copilot login + cmd.DoGitHubCopilotLogin(cfg, options) } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go new file mode 100644 index 00000000..fbfb1762 --- /dev/null +++ b/internal/auth/copilot/copilot_auth.go @@ -0,0 +1,301 @@ +// Package copilot provides authentication and token management for GitHub Copilot API. +// It handles the OAuth2 device flow for secure authentication with the Copilot API. +package copilot + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token. + copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token" + // copilotAPIEndpoint is the base URL for making API requests. + copilotAPIEndpoint = "https://api.githubcopilot.com" +) + +// CopilotAPIToken represents the Copilot API token response. +type CopilotAPIToken struct { + // Token is the JWT token for authenticating with the Copilot API. + Token string `json:"token"` + // ExpiresAt is the Unix timestamp when the token expires. + ExpiresAt int64 `json:"expires_at"` + // Endpoints contains the available API endpoints. + Endpoints struct { + API string `json:"api"` + Proxy string `json:"proxy"` + OriginTracker string `json:"origin-tracker"` + Telemetry string `json:"telemetry"` + } `json:"endpoints,omitempty"` + // ErrorDetails contains error information if the request failed. + ErrorDetails *struct { + URL string `json:"url"` + Message string `json:"message"` + DocumentationURL string `json:"documentation_url"` + } `json:"error_details,omitempty"` +} + +// CopilotAuth handles GitHub Copilot authentication flow. +// It provides methods for device flow authentication and token management. +type CopilotAuth struct { + httpClient *http.Client + deviceClient *DeviceFlowClient + cfg *config.Config +} + +// NewCopilotAuth creates a new CopilotAuth service instance. +// It initializes an HTTP client with proxy settings from the provided configuration. +func NewCopilotAuth(cfg *config.Config) *CopilotAuth { + return &CopilotAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + deviceClient: NewDeviceFlowClient(cfg), + cfg: cfg, + } +} + +// StartDeviceFlow initiates the device flow authentication. +// Returns the device code response containing the user code and verification URI. +func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) { + return c.deviceClient.RequestDeviceCode(ctx) +} + +// WaitForAuthorization polls for user authorization and returns the auth bundle. +func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) { + tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode) + if err != nil { + return nil, err + } + + // Fetch the GitHub username + username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken) + if err != nil { + log.Warnf("copilot: failed to fetch user info: %v", err) + username = "unknown" + } + + return &CopilotAuthBundle{ + TokenData: tokenData, + Username: username, + }, nil +} + +// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token. +// This token is used to make authenticated requests to the Copilot API. +func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) { + if githubAccessToken == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty")) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + req.Header.Set("Authorization", "token "+githubAccessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "GithubCopilot/1.0") + req.Header.Set("Editor-Version", "vscode/1.100.0") + req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot api token: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, + fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var apiToken CopilotAPIToken + if err = json.Unmarshal(bodyBytes, &apiToken); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if apiToken.Token == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token")) + } + + return &apiToken, nil +} + +// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info. +func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) { + if accessToken == "" { + return false, "", nil + } + + username, err := c.deviceClient.FetchUserInfo(ctx, accessToken) + if err != nil { + return false, "", err + } + + return true, username, nil +} + +// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle. +func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage { + return &CopilotTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + TokenType: bundle.TokenData.TokenType, + Scope: bundle.TokenData.Scope, + Username: bundle.Username, + Type: "github-copilot", + } +} + +// LoadAndValidateToken loads a token from storage and validates it. +// Returns the storage if valid, or an error if the token is invalid or expired. +func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) { + if storage == nil || storage.AccessToken == "" { + return false, fmt.Errorf("no token available") + } + + // Check if we can still use the GitHub token to get a Copilot API token + apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken) + if err != nil { + return false, err + } + + // Check if the API token is expired + if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt { + return false, fmt.Errorf("copilot api token expired") + } + + return true, nil +} + +// GetAPIEndpoint returns the Copilot API endpoint URL. +func (c *CopilotAuth) GetAPIEndpoint() string { + return copilotAPIEndpoint +} + +// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API. +func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+apiToken.Token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "GithubCopilot/1.0") + req.Header.Set("Editor-Version", "vscode/1.100.0") + req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + req.Header.Set("Openai-Intent", "conversation-panel") + req.Header.Set("Copilot-Integration-Id", "vscode-chat") + + return req, nil +} + +// CachedAPIToken manages caching of Copilot API tokens. +type CachedAPIToken struct { + Token *CopilotAPIToken + ExpiresAt time.Time +} + +// IsExpired checks if the cached token has expired. +func (c *CachedAPIToken) IsExpired() bool { + if c.Token == nil { + return true + } + // Add a 5-minute buffer before expiration + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +// TokenManager handles caching and refreshing of Copilot API tokens. +type TokenManager struct { + auth *CopilotAuth + githubToken string + cachedToken *CachedAPIToken +} + +// NewTokenManager creates a new token manager for handling Copilot API tokens. +func NewTokenManager(auth *CopilotAuth, githubToken string) *TokenManager { + return &TokenManager{ + auth: auth, + githubToken: githubToken, + } +} + +// GetToken returns a valid Copilot API token, refreshing if necessary. +func (tm *TokenManager) GetToken(ctx context.Context) (*CopilotAPIToken, error) { + if tm.cachedToken != nil && !tm.cachedToken.IsExpired() { + return tm.cachedToken.Token, nil + } + + // Fetch a new API token + apiToken, err := tm.auth.GetCopilotAPIToken(ctx, tm.githubToken) + if err != nil { + return nil, err + } + + // Cache the token + expiresAt := time.Now().Add(30 * time.Minute) // Default 30 min cache + if apiToken.ExpiresAt > 0 { + expiresAt = time.Unix(apiToken.ExpiresAt, 0) + } + + tm.cachedToken = &CachedAPIToken{ + Token: apiToken, + ExpiresAt: expiresAt, + } + + return apiToken, nil +} + +// GetAuthorizationHeader returns the authorization header value for API requests. +func (tm *TokenManager) GetAuthorizationHeader(ctx context.Context) (string, error) { + token, err := tm.GetToken(ctx) + if err != nil { + return "", err + } + return "Bearer " + token.Token, nil +} + +// UpdateGitHubToken updates the GitHub access token used for getting API tokens. +func (tm *TokenManager) UpdateGitHubToken(githubToken string) { + tm.githubToken = githubToken + tm.cachedToken = nil // Invalidate cache +} + +// BuildChatCompletionURL builds the URL for chat completions API. +func BuildChatCompletionURL() string { + return copilotAPIEndpoint + "/chat/completions" +} + +// BuildModelsURL builds the URL for listing available models. +func BuildModelsURL() string { + return copilotAPIEndpoint + "/models" +} + +// ExtractBearerToken extracts the bearer token from an Authorization header. +func ExtractBearerToken(authHeader string) string { + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + if strings.HasPrefix(authHeader, "token ") { + return strings.TrimPrefix(authHeader, "token ") + } + return authHeader +} diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go new file mode 100644 index 00000000..01f8e754 --- /dev/null +++ b/internal/auth/copilot/errors.go @@ -0,0 +1,183 @@ +package copilot + +import ( + "errors" + "fmt" + "net/http" +) + +// OAuthError represents an OAuth-specific error. +type OAuthError struct { + // Code is the OAuth error code. + Code string `json:"error"` + // Description is a human-readable description of the error. + Description string `json:"error_description,omitempty"` + // URI is a URI identifying a human-readable web page with information about the error. + URI string `json:"error_uri,omitempty"` + // StatusCode is the HTTP status code associated with the error. + StatusCode int `json:"-"` +} + +// Error returns a string representation of the OAuth error. +func (e *OAuthError) Error() string { + if e.Description != "" { + return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description) + } + return fmt.Sprintf("OAuth error: %s", e.Code) +} + +// NewOAuthError creates a new OAuth error with the specified code, description, and status code. +func NewOAuthError(code, description string, statusCode int) *OAuthError { + return &OAuthError{ + Code: code, + Description: description, + StatusCode: statusCode, + } +} + +// AuthenticationError represents authentication-related errors. +type AuthenticationError struct { + // Type is the type of authentication error. + Type string `json:"type"` + // Message is a human-readable message describing the error. + Message string `json:"message"` + // Code is the HTTP status code associated with the error. + Code int `json:"code"` + // Cause is the underlying error that caused this authentication error. + Cause error `json:"-"` +} + +// Error returns a string representation of the authentication error. +func (e *AuthenticationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("%s: %s", e.Type, e.Message) +} + +// Common authentication error types for GitHub Copilot device flow. +var ( + // ErrDeviceCodeFailed represents an error when requesting the device code fails. + ErrDeviceCodeFailed = &AuthenticationError{ + Type: "device_code_failed", + Message: "Failed to request device code from GitHub", + Code: http.StatusBadRequest, + } + + // ErrDeviceCodeExpired represents an error when the device code has expired. + ErrDeviceCodeExpired = &AuthenticationError{ + Type: "device_code_expired", + Message: "Device code has expired. Please try again.", + Code: http.StatusGone, + } + + // ErrAuthorizationPending represents a pending authorization state (not an error, used for polling). + ErrAuthorizationPending = &AuthenticationError{ + Type: "authorization_pending", + Message: "Authorization is pending. Waiting for user to authorize.", + Code: http.StatusAccepted, + } + + // ErrSlowDown represents a request to slow down polling. + ErrSlowDown = &AuthenticationError{ + Type: "slow_down", + Message: "Polling too frequently. Slowing down.", + Code: http.StatusTooManyRequests, + } + + // ErrAccessDenied represents an error when the user denies authorization. + ErrAccessDenied = &AuthenticationError{ + Type: "access_denied", + Message: "User denied authorization", + Code: http.StatusForbidden, + } + + // ErrTokenExchangeFailed represents an error when token exchange fails. + ErrTokenExchangeFailed = &AuthenticationError{ + Type: "token_exchange_failed", + Message: "Failed to exchange device code for access token", + Code: http.StatusBadRequest, + } + + // ErrPollingTimeout represents an error when polling times out. + ErrPollingTimeout = &AuthenticationError{ + Type: "polling_timeout", + Message: "Timeout waiting for user authorization", + Code: http.StatusRequestTimeout, + } + + // ErrUserInfoFailed represents an error when fetching user info fails. + ErrUserInfoFailed = &AuthenticationError{ + Type: "user_info_failed", + Message: "Failed to fetch GitHub user information", + Code: http.StatusBadRequest, + } +) + +// NewAuthenticationError creates a new authentication error with a cause based on a base error. +func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError { + return &AuthenticationError{ + Type: baseErr.Type, + Message: baseErr.Message, + Code: baseErr.Code, + Cause: cause, + } +} + +// IsAuthenticationError checks if an error is an authentication error. +func IsAuthenticationError(err error) bool { + var authenticationError *AuthenticationError + ok := errors.As(err, &authenticationError) + return ok +} + +// IsOAuthError checks if an error is an OAuth error. +func IsOAuthError(err error) bool { + var oAuthError *OAuthError + ok := errors.As(err, &oAuthError) + return ok +} + +// GetUserFriendlyMessage returns a user-friendly error message based on the error type. +func GetUserFriendlyMessage(err error) string { + switch { + case IsAuthenticationError(err): + var authErr *AuthenticationError + errors.As(err, &authErr) + switch authErr.Type { + case "device_code_failed": + return "Failed to start GitHub authentication. Please check your network connection and try again." + case "device_code_expired": + return "The authentication code has expired. Please try again." + case "authorization_pending": + return "Waiting for you to authorize the application on GitHub." + case "slow_down": + return "Please wait a moment before trying again." + case "access_denied": + return "Authentication was cancelled or denied." + case "token_exchange_failed": + return "Failed to complete authentication. Please try again." + case "polling_timeout": + return "Authentication timed out. Please try again." + case "user_info_failed": + return "Failed to get your GitHub account information. Please try again." + default: + return "Authentication failed. Please try again." + } + case IsOAuthError(err): + var oauthErr *OAuthError + errors.As(err, &oauthErr) + switch oauthErr.Code { + case "access_denied": + return "Authentication was cancelled or denied." + case "invalid_request": + return "Invalid authentication request. Please try again." + case "server_error": + return "GitHub server error. Please try again later." + default: + return fmt.Sprintf("Authentication failed: %s", oauthErr.Description) + } + default: + return "An unexpected error occurred. Please try again." + } +} diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go new file mode 100644 index 00000000..3f7877b6 --- /dev/null +++ b/internal/auth/copilot/oauth.go @@ -0,0 +1,259 @@ +package copilot + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // copilotClientID is GitHub's Copilot CLI OAuth client ID. + copilotClientID = "Iv1.b507a08c87ecfe98" + // copilotDeviceCodeURL is the endpoint for requesting device codes. + copilotDeviceCodeURL = "https://github.com/login/device/code" + // copilotTokenURL is the endpoint for exchanging device codes for tokens. + copilotTokenURL = "https://github.com/login/oauth/access_token" + // copilotUserInfoURL is the endpoint for fetching GitHub user information. + copilotUserInfoURL = "https://api.github.com/user" + // defaultPollInterval is the default interval for polling token endpoint. + defaultPollInterval = 5 * time.Second + // maxPollDuration is the maximum time to wait for user authorization. + maxPollDuration = 15 * time.Minute +) + +// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot. +type DeviceFlowClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewDeviceFlowClient creates a new device flow client. +func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &DeviceFlowClient{ + httpClient: client, + cfg: cfg, + } +} + +// RequestDeviceCode initiates the device flow by requesting a device code from GitHub. +func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) { + data := url.Values{} + data.Set("client_id", copilotClientID) + data.Set("scope", "user:email") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot device code: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var deviceCode DeviceCodeResponse + if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + + return &deviceCode, nil +} + +// PollForToken polls the token endpoint until the user authorizes or the device code expires. +func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) { + if deviceCode == nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil")) + } + + interval := time.Duration(deviceCode.Interval) * time.Second + if interval < defaultPollInterval { + interval = defaultPollInterval + } + + deadline := time.Now().Add(maxPollDuration) + if deviceCode.ExpiresIn > 0 { + codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second) + if codeDeadline.Before(deadline) { + deadline = codeDeadline + } + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err()) + case <-ticker.C: + if time.Now().After(deadline) { + return nil, ErrPollingTimeout + } + + token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode) + if err != nil { + var authErr *AuthenticationError + if IsAuthenticationError(err) { + if ok := (err.(*AuthenticationError)); ok != nil { + authErr = ok + } + } + if authErr != nil { + switch authErr.Type { + case ErrAuthorizationPending.Type: + // Continue polling + continue + case ErrSlowDown.Type: + // Increase interval and continue + interval += 5 * time.Second + ticker.Reset(interval) + continue + case ErrDeviceCodeExpired.Type: + return nil, err + case ErrAccessDenied.Type: + return nil, err + } + } + return nil, err + } + return token, nil + } + } +} + +// exchangeDeviceCode attempts to exchange the device code for an access token. +func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) { + data := url.Values{} + data.Set("client_id", copilotClientID) + data.Set("device_code", deviceCode) + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot token exchange: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + // GitHub returns 200 for both success and error cases in device flow + // Check for OAuth error response first + var oauthResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + } + + if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if oauthResp.Error != "" { + switch oauthResp.Error { + case "authorization_pending": + return nil, ErrAuthorizationPending + case "slow_down": + return nil, ErrSlowDown + case "expired_token": + return nil, ErrDeviceCodeExpired + case "access_denied": + return nil, ErrAccessDenied + default: + return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode) + } + } + + if oauthResp.AccessToken == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty access token")) + } + + return &CopilotTokenData{ + AccessToken: oauthResp.AccessToken, + TokenType: oauthResp.TokenType, + Scope: oauthResp.Scope, + }, nil +} + +// FetchUserInfo retrieves the GitHub username for the authenticated user. +func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) { + if accessToken == "" { + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty")) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil) + if err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "CLIProxyAPI") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot user info: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var userInfo struct { + Login string `json:"login"` + } + if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + + if userInfo.Login == "" { + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username")) + } + + return userInfo.Login, nil +} diff --git a/internal/auth/copilot/token.go b/internal/auth/copilot/token.go new file mode 100644 index 00000000..4e5eed6c --- /dev/null +++ b/internal/auth/copilot/token.go @@ -0,0 +1,93 @@ +// Package copilot provides authentication and token management functionality +// for GitHub Copilot AI services. It handles OAuth2 device flow token storage, +// serialization, and retrieval for maintaining authenticated sessions with the Copilot API. +package copilot + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication. +// It maintains compatibility with the existing auth system while adding Copilot-specific fields +// for managing access tokens and user account information. +type CopilotTokenStorage struct { + // AccessToken is the OAuth2 access token used for authenticating API requests. + AccessToken string `json:"access_token"` + // TokenType is the type of token, typically "bearer". + TokenType string `json:"token_type"` + // Scope is the OAuth2 scope granted to the token. + Scope string `json:"scope"` + // ExpiresAt is the timestamp when the access token expires (if provided). + ExpiresAt string `json:"expires_at,omitempty"` + // Username is the GitHub username associated with this token. + Username string `json:"username"` + // Type indicates the authentication provider type, always "github-copilot" for this storage. + Type string `json:"type"` +} + +// CopilotTokenData holds the raw OAuth token response from GitHub. +type CopilotTokenData struct { + // AccessToken is the OAuth2 access token. + AccessToken string `json:"access_token"` + // TokenType is the type of token, typically "bearer". + TokenType string `json:"token_type"` + // Scope is the OAuth2 scope granted to the token. + Scope string `json:"scope"` +} + +// CopilotAuthBundle bundles authentication data for storage. +type CopilotAuthBundle struct { + // TokenData contains the OAuth token information. + TokenData *CopilotTokenData + // Username is the GitHub username. + Username string +} + +// DeviceCodeResponse represents GitHub's device code response. +type DeviceCodeResponse struct { + // DeviceCode is the device verification code. + DeviceCode string `json:"device_code"` + // UserCode is the code the user must enter at the verification URI. + UserCode string `json:"user_code"` + // VerificationURI is the URL where the user should enter the code. + VerificationURI string `json:"verification_uri"` + // ExpiresIn is the number of seconds until the device code expires. + ExpiresIn int `json:"expires_in"` + // Interval is the minimum number of seconds to wait between polling requests. + Interval int `json:"interval"` +} + +// SaveTokenToFile serializes the Copilot token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path for persistent storage. +// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "github-copilot" + if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index e6caa954..baf43bae 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -6,7 +6,7 @@ import ( // newAuthManager creates a new authentication manager instance with all supported // authenticators and a file-based token store. It initializes authenticators for -// Gemini, Codex, Claude, and Qwen providers. +// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers. // // Returns: // - *sdkAuth.Manager: A configured authentication manager instance @@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), + sdkAuth.NewGitHubCopilotAuthenticator(), ) return manager } diff --git a/internal/cmd/github_copilot_login.go b/internal/cmd/github_copilot_login.go new file mode 100644 index 00000000..056e811f --- /dev/null +++ b/internal/cmd/github_copilot_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens. +// It initiates the device flow authentication, displays the user code for the user to enter +// at GitHub's verification URL, and waits for authorization before saving the tokens. +// +// Parameters: +// - cfg: The application configuration containing proxy and auth directory settings +// - options: Login options including browser behavior settings +func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts) + if err != nil { + log.Errorf("GitHub Copilot authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("GitHub Copilot authentication successful!") +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 51f984f2..42e68239 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -984,3 +984,187 @@ func GetIFlowModels() []*ModelInfo { } return models } + +// GetGitHubCopilotModels returns the available models for GitHub Copilot. +// These models are available through the GitHub Copilot API at api.githubcopilot.com. +func GetGitHubCopilotModels() []*ModelInfo { + now := int64(1732752000) // 2024-11-27 + return []*ModelInfo{ + { + ID: "gpt-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-4.1", + Description: "OpenAI GPT-4.1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "gpt-5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5", + Description: "OpenAI GPT-5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Mini", + Description: "OpenAI GPT-5 Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "gpt-5-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Codex", + Description: "OpenAI GPT-5 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1", + Description: "OpenAI GPT-5.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex", + Description: "OpenAI GPT-5.1 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1-codex-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex Mini", + Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "claude-haiku-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Haiku 4.5", + Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-opus-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.1", + Description: "Anthropic Claude Opus 4.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32000, + }, + { + ID: "claude-opus-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.5", + Description: "Anthropic Claude Opus 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-sonnet-4", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4", + Description: "Anthropic Claude Sonnet 4 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-sonnet-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4.5", + Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "gemini-2.5-pro", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 2.5 Pro", + Description: "Google Gemini 2.5 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + }, + { + ID: "gemini-3-pro", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 3 Pro", + Description: "Google Gemini 3 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + }, + { + ID: "grok-code-fast-1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Grok Code Fast 1", + Description: "xAI Grok Code Fast 1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "raptor-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Raptor Mini", + Description: "Raptor Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + } +} diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go new file mode 100644 index 00000000..0d1240f7 --- /dev/null +++ b/internal/runtime/executor/github_copilot_executor.go @@ -0,0 +1,354 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/google/uuid" + copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" +) + +const ( + githubCopilotBaseURL = "https://api.githubcopilot.com" + githubCopilotChatPath = "/chat/completions" + githubCopilotAuthType = "github-copilot" + githubCopilotTokenCacheTTL = 25 * time.Minute +) + +// GitHubCopilotExecutor handles requests to the GitHub Copilot API. +type GitHubCopilotExecutor struct { + cfg *config.Config +} + +// NewGitHubCopilotExecutor constructs a new executor instance. +func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor { + return &GitHubCopilotExecutor{cfg: cfg} +} + +// Identifier implements ProviderExecutor. +func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType } + +// PrepareRequest implements ProviderExecutor. +func (e *GitHubCopilotExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { + return nil +} + +// Execute handles non-streaming requests to GitHub Copilot. +func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + apiToken, errToken := e.ensureAPIToken(ctx, auth) + if errToken != nil { + return resp, errToken + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body = e.normalizeModel(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "stream", false) + + url := githubCopilotBaseURL + githubCopilotChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return resp, err + } + e.applyHeaders(httpReq, apiToken) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + }() + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + data, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return resp, err + } + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + + detail := parseOpenAIUsage(data) + if detail.TotalTokens > 0 { + reporter.publish(ctx, detail) + } + + var param any + converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(converted)} + reporter.ensurePublished(ctx) + return resp, nil +} + +// ExecuteStream handles streaming requests to GitHub Copilot. +func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + apiToken, errToken := e.ensureAPIToken(ctx, auth) + if errToken != nil { + return nil, errToken + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + body = e.normalizeModel(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "stream", true) + // Enable stream options for usage stats in stream + body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) + + url := githubCopilotBaseURL + githubCopilotChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + e.applyHeaders(httpReq, apiToken) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + data, readErr := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + return nil, readErr + } + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return nil, err + } + + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + }() + + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 20_971_520) + var param any + + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + + // Parse SSE data + if bytes.HasPrefix(line, dataTag) { + data := bytes.TrimSpace(line[5:]) + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if detail, ok := parseOpenAIStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + } + + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } else { + reporter.ensurePublished(ctx) + } + }() + + return stream, nil +} + +// CountTokens is not supported for GitHub Copilot. +func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"} +} + +// Refresh validates the GitHub token is still working. +// GitHub OAuth tokens don't expire traditionally, so we just validate. +func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + // Get the GitHub access token + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + return auth, nil + } + + // Validate the token can still get a Copilot API token + copilotAuth := copilotauth.NewCopilotAuth(e.cfg) + _, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) + if err != nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)} + } + + return auth, nil +} + +// ensureAPIToken gets or refreshes the Copilot API token. +func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) { + if auth == nil { + return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + // Check for cached API token + if cachedToken := metaStringValue(auth.Metadata, "copilot_api_token"); cachedToken != "" { + if expiresAt := tokenExpiry(auth.Metadata); expiresAt.After(time.Now().Add(5 * time.Minute)) { + return cachedToken, nil + } + } + + // Get the GitHub access token + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"} + } + + // Get a new Copilot API token + copilotAuth := copilotauth.NewCopilotAuth(e.cfg) + apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) + if err != nil { + return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)} + } + + // Cache the token in metadata (will be persisted on next save) + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["copilot_api_token"] = apiToken.Token + if apiToken.ExpiresAt > 0 { + auth.Metadata["expired"] = time.Unix(apiToken.ExpiresAt, 0).Format(time.RFC3339) + } else { + auth.Metadata["expired"] = time.Now().Add(githubCopilotTokenCacheTTL).Format(time.RFC3339) + } + + return apiToken.Token, nil +} + +// applyHeaders sets the required headers for GitHub Copilot API requests. +func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", "Bearer "+apiToken) + r.Header.Set("Accept", "application/json") + r.Header.Set("User-Agent", "GithubCopilot/1.0") + r.Header.Set("Editor-Version", "vscode/1.100.0") + r.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + r.Header.Set("Openai-Intent", "conversation-panel") + r.Header.Set("Copilot-Integration-Id", "vscode-chat") + r.Header.Set("X-Request-Id", uuid.NewString()) +} + +// normalizeModel ensures the model name is correct for the API. +func (e *GitHubCopilotExecutor) normalizeModel(requestedModel string, body []byte) []byte { + // Map friendly names to API model names + modelMap := map[string]string{ + "gpt-4.1": "gpt-4.1", + "gpt-5": "gpt-5", + "gpt-5-mini": "gpt-5-mini", + "gpt-5-codex": "gpt-5-codex", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "claude-haiku-4.5": "claude-haiku-4.5", + "claude-opus-4.1": "claude-opus-4.1", + "claude-opus-4.5": "claude-opus-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-pro": "gemini-3-pro", + "grok-code-fast-1": "grok-code-fast-1", + "raptor-mini": "raptor-mini", + } + + if mapped, ok := modelMap[requestedModel]; ok { + body, _ = sjson.SetBytes(body, "model", mapped) + } + + return body +} diff --git a/sdk/auth/github_copilot.go b/sdk/auth/github_copilot.go new file mode 100644 index 00000000..3ffcf133 --- /dev/null +++ b/sdk/auth/github_copilot.go @@ -0,0 +1,133 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot. +type GitHubCopilotAuthenticator struct{} + +// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator. +func NewGitHubCopilotAuthenticator() Authenticator { + return &GitHubCopilotAuthenticator{} +} + +// Provider returns the provider key for github-copilot. +func (GitHubCopilotAuthenticator) Provider() string { + return "github-copilot" +} + +// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense. +// The token remains valid until the user revokes it or the Copilot subscription expires. +func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration { + return nil +} + +// Login initiates the GitHub device flow authentication for Copilot access. +func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if ctx == nil { + ctx = context.Background() + } + if opts == nil { + opts = &LoginOptions{} + } + + authSvc := copilot.NewCopilotAuth(cfg) + + // Start the device flow + fmt.Println("Starting GitHub Copilot authentication...") + deviceCode, err := authSvc.StartDeviceFlow(ctx) + if err != nil { + return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err) + } + + // Display the user code and verification URL + fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI) + fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode) + + // Try to open the browser automatically + if !opts.NoBrowser { + if browser.IsAvailable() { + if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + } + + fmt.Println("Waiting for GitHub authorization...") + fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn) + + // Wait for user authorization + authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode) + if err != nil { + errMsg := copilot.GetUserFriendlyMessage(err) + return nil, fmt.Errorf("github-copilot: %s", errMsg) + } + + // Verify the token can get a Copilot API token + fmt.Println("Verifying Copilot access...") + apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken) + if err != nil { + return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err) + } + + // Create the token storage + tokenStorage := authSvc.CreateTokenStorage(authBundle) + + // Build metadata + metadata := map[string]any{ + "type": "github-copilot", + "username": authBundle.Username, + "access_token": authBundle.TokenData.AccessToken, + "token_type": authBundle.TokenData.TokenType, + "scope": authBundle.TokenData.Scope, + "timestamp": time.Now().UnixMilli(), + "api_endpoint": copilot.BuildChatCompletionURL(), + } + + if apiToken.ExpiresAt > 0 { + metadata["api_token_expires_at"] = apiToken.ExpiresAt + } + + fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username) + + fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username) + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: authBundle.Username, + Storage: tokenStorage, + Metadata: metadata, + }, nil +} + +// RefreshGitHubCopilotToken validates and returns the current token status. +// GitHub OAuth tokens don't need traditional refresh - we just validate they still work. +func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error { + if storage == nil || storage.AccessToken == "" { + return fmt.Errorf("no token available") + } + + authSvc := copilot.NewCopilotAuth(cfg) + + // Validate the token can still get a Copilot API token + _, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken) + if err != nil { + return fmt.Errorf("token validation failed: %w", err) + } + + return nil +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index e82ac684..e8997f71 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -14,6 +14,7 @@ func init() { registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) + registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6e303ed2..8b66a9a9 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -351,6 +351,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg)) + case "github-copilot": + s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -662,6 +664,8 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = registry.GetQwenModels() case "iflow": models = registry.GetIFlowModels() + case "github-copilot": + models = registry.GetGitHubCopilotModels() default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { From efd28bf981a33f15323f655411ef5405a26b80e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Thu, 27 Nov 2025 22:47:51 +0100 Subject: [PATCH 02/16] refactor(copilot): address PR review feedback - Simplify error type checking in oauth.go using errors.As directly - Remove redundant errors.As call in GetUserFriendlyMessage - Remove unused CachedAPIToken and TokenManager types (dead code) --- internal/auth/copilot/copilot_auth.go | 71 --------------------------- internal/auth/copilot/errors.go | 17 +++---- internal/auth/copilot/oauth.go | 8 +-- 3 files changed, 10 insertions(+), 86 deletions(-) diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go index fbfb1762..e4e5876b 100644 --- a/internal/auth/copilot/copilot_auth.go +++ b/internal/auth/copilot/copilot_auth.go @@ -208,77 +208,6 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url return req, nil } -// CachedAPIToken manages caching of Copilot API tokens. -type CachedAPIToken struct { - Token *CopilotAPIToken - ExpiresAt time.Time -} - -// IsExpired checks if the cached token has expired. -func (c *CachedAPIToken) IsExpired() bool { - if c.Token == nil { - return true - } - // Add a 5-minute buffer before expiration - return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) -} - -// TokenManager handles caching and refreshing of Copilot API tokens. -type TokenManager struct { - auth *CopilotAuth - githubToken string - cachedToken *CachedAPIToken -} - -// NewTokenManager creates a new token manager for handling Copilot API tokens. -func NewTokenManager(auth *CopilotAuth, githubToken string) *TokenManager { - return &TokenManager{ - auth: auth, - githubToken: githubToken, - } -} - -// GetToken returns a valid Copilot API token, refreshing if necessary. -func (tm *TokenManager) GetToken(ctx context.Context) (*CopilotAPIToken, error) { - if tm.cachedToken != nil && !tm.cachedToken.IsExpired() { - return tm.cachedToken.Token, nil - } - - // Fetch a new API token - apiToken, err := tm.auth.GetCopilotAPIToken(ctx, tm.githubToken) - if err != nil { - return nil, err - } - - // Cache the token - expiresAt := time.Now().Add(30 * time.Minute) // Default 30 min cache - if apiToken.ExpiresAt > 0 { - expiresAt = time.Unix(apiToken.ExpiresAt, 0) - } - - tm.cachedToken = &CachedAPIToken{ - Token: apiToken, - ExpiresAt: expiresAt, - } - - return apiToken, nil -} - -// GetAuthorizationHeader returns the authorization header value for API requests. -func (tm *TokenManager) GetAuthorizationHeader(ctx context.Context) (string, error) { - token, err := tm.GetToken(ctx) - if err != nil { - return "", err - } - return "Bearer " + token.Token, nil -} - -// UpdateGitHubToken updates the GitHub access token used for getting API tokens. -func (tm *TokenManager) UpdateGitHubToken(githubToken string) { - tm.githubToken = githubToken - tm.cachedToken = nil // Invalidate cache -} - // BuildChatCompletionURL builds the URL for chat completions API. func BuildChatCompletionURL() string { return copilotAPIEndpoint + "/chat/completions" diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go index 01f8e754..dac6ecfa 100644 --- a/internal/auth/copilot/errors.go +++ b/internal/auth/copilot/errors.go @@ -140,10 +140,8 @@ func IsOAuthError(err error) bool { // GetUserFriendlyMessage returns a user-friendly error message based on the error type. func GetUserFriendlyMessage(err error) string { - switch { - case IsAuthenticationError(err): - var authErr *AuthenticationError - errors.As(err, &authErr) + var authErr *AuthenticationError + if errors.As(err, &authErr) { switch authErr.Type { case "device_code_failed": return "Failed to start GitHub authentication. Please check your network connection and try again." @@ -164,9 +162,10 @@ func GetUserFriendlyMessage(err error) string { default: return "Authentication failed. Please try again." } - case IsOAuthError(err): - var oauthErr *OAuthError - errors.As(err, &oauthErr) + } + + var oauthErr *OAuthError + if errors.As(err, &oauthErr) { switch oauthErr.Code { case "access_denied": return "Authentication was cancelled or denied." @@ -177,7 +176,7 @@ func GetUserFriendlyMessage(err error) string { default: return fmt.Sprintf("Authentication failed: %s", oauthErr.Description) } - default: - return "An unexpected error occurred. Please try again." } + + return "An unexpected error occurred. Please try again." } diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go index 3f7877b6..1aecf596 100644 --- a/internal/auth/copilot/oauth.go +++ b/internal/auth/copilot/oauth.go @@ -3,6 +3,7 @@ package copilot import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -118,12 +119,7 @@ func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceC token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode) if err != nil { var authErr *AuthenticationError - if IsAuthenticationError(err) { - if ok := (err.(*AuthenticationError)); ok != nil { - authErr = ok - } - } - if authErr != nil { + if errors.As(err, &authErr) { switch authErr.Type { case ErrAuthorizationPending.Type: // Continue polling From 2c296e9cb19ef48bf2b8f51911ab4fd7bf115b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Fri, 28 Nov 2025 08:32:36 +0100 Subject: [PATCH 03/16] refactor(copilot): improve code quality in authentication module - Add Unwrap() to AuthenticationError for proper error chain handling with errors.Is/As - Extract hardcoded header values to constants for maintainability - Replace verbose status code checks with isHTTPSuccess() helper - Remove unused ExtractBearerToken() and BuildModelsURL() functions - Make buildChatCompletionURL() private (only used internally) - Remove unused 'strings' import --- internal/auth/copilot/copilot_auth.go | 47 ++++++++++++--------------- internal/auth/copilot/errors.go | 5 +++ internal/auth/copilot/oauth.go | 4 +-- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go index e4e5876b..c40e7082 100644 --- a/internal/auth/copilot/copilot_auth.go +++ b/internal/auth/copilot/copilot_auth.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net/http" - "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -21,6 +20,13 @@ const ( copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token" // copilotAPIEndpoint is the base URL for making API requests. copilotAPIEndpoint = "https://api.githubcopilot.com" + + // Common HTTP header values for Copilot API requests. + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" ) // CopilotAPIToken represents the Copilot API token response. @@ -102,9 +108,9 @@ func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken req.Header.Set("Authorization", "token "+githubAccessToken) req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "GithubCopilot/1.0") - req.Header.Set("Editor-Version", "vscode/1.100.0") - req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) resp, err := c.httpClient.Do(req) if err != nil { @@ -121,7 +127,7 @@ func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(resp.StatusCode) { return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) } @@ -199,32 +205,21 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url req.Header.Set("Authorization", "Bearer "+apiToken.Token) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "GithubCopilot/1.0") - req.Header.Set("Editor-Version", "vscode/1.100.0") - req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") - req.Header.Set("Openai-Intent", "conversation-panel") - req.Header.Set("Copilot-Integration-Id", "vscode-chat") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + req.Header.Set("Openai-Intent", copilotOpenAIIntent) + req.Header.Set("Copilot-Integration-Id", copilotIntegrationID) return req, nil } -// BuildChatCompletionURL builds the URL for chat completions API. -func BuildChatCompletionURL() string { +// buildChatCompletionURL builds the URL for chat completions API. +func buildChatCompletionURL() string { return copilotAPIEndpoint + "/chat/completions" } -// BuildModelsURL builds the URL for listing available models. -func BuildModelsURL() string { - return copilotAPIEndpoint + "/models" -} - -// ExtractBearerToken extracts the bearer token from an Authorization header. -func ExtractBearerToken(authHeader string) string { - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimPrefix(authHeader, "Bearer ") - } - if strings.HasPrefix(authHeader, "token ") { - return strings.TrimPrefix(authHeader, "token ") - } - return authHeader +// isHTTPSuccess checks if the status code indicates success (2xx). +func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 } diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go index dac6ecfa..a82dd8ec 100644 --- a/internal/auth/copilot/errors.go +++ b/internal/auth/copilot/errors.go @@ -55,6 +55,11 @@ func (e *AuthenticationError) Error() string { return fmt.Sprintf("%s: %s", e.Type, e.Message) } +// Unwrap returns the underlying cause of the error. +func (e *AuthenticationError) Unwrap() error { + return e.Cause +} + // Common authentication error types for GitHub Copilot device flow. var ( // ErrDeviceCodeFailed represents an error when requesting the device code fails. diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go index 1aecf596..d3f46aaa 100644 --- a/internal/auth/copilot/oauth.go +++ b/internal/auth/copilot/oauth.go @@ -72,7 +72,7 @@ func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeRe } }() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(resp.StatusCode) { bodyBytes, _ := io.ReadAll(resp.Body) return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) } @@ -235,7 +235,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string } }() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(resp.StatusCode) { bodyBytes, _ := io.ReadAll(resp.Body) return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) } From 7515090cb686f0a502af24f59dbe53aaecf135ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Fri, 28 Nov 2025 08:33:51 +0100 Subject: [PATCH 04/16] refactor(executor): improve concurrency and code quality in GitHub Copilot executor - Replace concurrent-unsafe metadata caching with thread-safe sync.RWMutex-protected map - Extract magic numbers and hardcoded header values to named constants - Replace verbose status code checks with isHTTPSuccess() helper - Simplify normalizeModel() to no-op with explanatory comment (models already canonical) - Remove redundant metadata manipulation in token caching - Improve code clarity and performance with proper cache management --- .../executor/github_copilot_executor.go | 109 ++++++++++-------- sdk/auth/github_copilot.go | 6 +- 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index 0d1240f7..b2bc73df 100644 --- a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/google/uuid" @@ -24,16 +25,38 @@ const ( githubCopilotChatPath = "/chat/completions" githubCopilotAuthType = "github-copilot" githubCopilotTokenCacheTTL = 25 * time.Minute + // tokenExpiryBuffer is the time before expiry when we should refresh the token. + tokenExpiryBuffer = 5 * time.Minute + // maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB). + maxScannerBufferSize = 20_971_520 + + // Copilot API header values. + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" ) // GitHubCopilotExecutor handles requests to the GitHub Copilot API. type GitHubCopilotExecutor struct { - cfg *config.Config + cfg *config.Config + mu sync.RWMutex + cache map[string]*cachedAPIToken +} + +// cachedAPIToken stores a cached Copilot API token with its expiry. +type cachedAPIToken struct { + token string + expiresAt time.Time } // NewGitHubCopilotExecutor constructs a new executor instance. func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor { - return &GitHubCopilotExecutor{cfg: cfg} + return &GitHubCopilotExecutor{ + cfg: cfg, + cache: make(map[string]*cachedAPIToken), + } } // Identifier implements ProviderExecutor. @@ -100,7 +123,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(httpResp.StatusCode) { data, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, data) log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) @@ -180,7 +203,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(httpResp.StatusCode) { data, readErr := io.ReadAll(httpResp.Body) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("github-copilot executor: close response body error: %v", errClose) @@ -207,7 +230,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox }() scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 20_971_520) + scanner.Buffer(nil, maxScannerBufferSize) var param any for scanner.Scan() { @@ -277,19 +300,20 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"} } - // Check for cached API token - if cachedToken := metaStringValue(auth.Metadata, "copilot_api_token"); cachedToken != "" { - if expiresAt := tokenExpiry(auth.Metadata); expiresAt.After(time.Now().Add(5 * time.Minute)) { - return cachedToken, nil - } - } - // Get the GitHub access token accessToken := metaStringValue(auth.Metadata, "access_token") if accessToken == "" { return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"} } + // Check for cached API token using thread-safe access + e.mu.RLock() + if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) { + e.mu.RUnlock() + return cached.token, nil + } + e.mu.RUnlock() + // Get a new Copilot API token copilotAuth := copilotauth.NewCopilotAuth(e.cfg) apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) @@ -297,16 +321,17 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)} } - // Cache the token in metadata (will be persisted on next save) - if auth.Metadata == nil { - auth.Metadata = make(map[string]any) - } - auth.Metadata["copilot_api_token"] = apiToken.Token + // Cache the token with thread-safe access + expiresAt := time.Now().Add(githubCopilotTokenCacheTTL) if apiToken.ExpiresAt > 0 { - auth.Metadata["expired"] = time.Unix(apiToken.ExpiresAt, 0).Format(time.RFC3339) - } else { - auth.Metadata["expired"] = time.Now().Add(githubCopilotTokenCacheTTL).Format(time.RFC3339) + expiresAt = time.Unix(apiToken.ExpiresAt, 0) } + e.mu.Lock() + e.cache[accessToken] = &cachedAPIToken{ + token: apiToken.Token, + expiresAt: expiresAt, + } + e.mu.Unlock() return apiToken.Token, nil } @@ -316,39 +341,21 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) { r.Header.Set("Content-Type", "application/json") r.Header.Set("Authorization", "Bearer "+apiToken) r.Header.Set("Accept", "application/json") - r.Header.Set("User-Agent", "GithubCopilot/1.0") - r.Header.Set("Editor-Version", "vscode/1.100.0") - r.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") - r.Header.Set("Openai-Intent", "conversation-panel") - r.Header.Set("Copilot-Integration-Id", "vscode-chat") + r.Header.Set("User-Agent", copilotUserAgent) + r.Header.Set("Editor-Version", copilotEditorVersion) + r.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + r.Header.Set("Openai-Intent", copilotOpenAIIntent) + r.Header.Set("Copilot-Integration-Id", copilotIntegrationID) r.Header.Set("X-Request-Id", uuid.NewString()) } -// normalizeModel ensures the model name is correct for the API. -func (e *GitHubCopilotExecutor) normalizeModel(requestedModel string, body []byte) []byte { - // Map friendly names to API model names - modelMap := map[string]string{ - "gpt-4.1": "gpt-4.1", - "gpt-5": "gpt-5", - "gpt-5-mini": "gpt-5-mini", - "gpt-5-codex": "gpt-5-codex", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "claude-haiku-4.5": "claude-haiku-4.5", - "claude-opus-4.1": "claude-opus-4.1", - "claude-opus-4.5": "claude-opus-4.5", - "claude-sonnet-4": "claude-sonnet-4", - "claude-sonnet-4.5": "claude-sonnet-4.5", - "gemini-2.5-pro": "gemini-2.5-pro", - "gemini-3-pro": "gemini-3-pro", - "grok-code-fast-1": "grok-code-fast-1", - "raptor-mini": "raptor-mini", - } - - if mapped, ok := modelMap[requestedModel]; ok { - body, _ = sjson.SetBytes(body, "model", mapped) - } - +// normalizeModel is a no-op as GitHub Copilot accepts model names directly. +// Model mapping should be done at the registry level if needed. +func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte { return body } + +// isHTTPSuccess checks if the status code indicates success (2xx). +func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 +} diff --git a/sdk/auth/github_copilot.go b/sdk/auth/github_copilot.go index 3ffcf133..1d14ac47 100644 --- a/sdk/auth/github_copilot.go +++ b/sdk/auth/github_copilot.go @@ -36,9 +36,6 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi if cfg == nil { return nil, fmt.Errorf("cliproxy auth: configuration is required") } - if ctx == nil { - ctx = context.Background() - } if opts == nil { opts = &LoginOptions{} } @@ -85,7 +82,7 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi // Create the token storage tokenStorage := authSvc.CreateTokenStorage(authBundle) - // Build metadata + // Build metadata with token information for the executor metadata := map[string]any{ "type": "github-copilot", "username": authBundle.Username, @@ -93,7 +90,6 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi "token_type": authBundle.TokenData.TokenType, "scope": authBundle.TokenData.Scope, "timestamp": time.Now().UnixMilli(), - "api_endpoint": copilot.BuildChatCompletionURL(), } if apiToken.ExpiresAt > 0 { From 0087eecad8516e4021bdeeff5765acd5a0ac6837 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 30 Nov 2025 17:23:47 +0800 Subject: [PATCH 05/16] **fix(build): append '-plus' suffix to version metadata in workflows and Goreleaser** - Updated release and Docker workflows to ensure the `-plus` suffix is added to build versions when missing. - Adjusted Goreleaser configuration to include `-plus` suffix in `main.Version` during build process. --- .github/workflows/docker-image.yml | 7 +++++-- .github/workflows/release.yaml | 6 +++++- .goreleaser.yml | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 3aacf4f5..63e1c4b7 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -26,7 +26,11 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Generate Build Metadata run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV + VERSION=$(git describe --tags --always --dirty) + if [[ "$VERSION" != *"-plus" ]]; then + VERSION="${VERSION}-plus" + fi + echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - name: Build and push @@ -42,5 +46,4 @@ jobs: COMMIT=${{ env.COMMIT }} BUILD_DATE=${{ env.BUILD_DATE }} tags: | - ${{ env.DOCKERHUB_REPO }}:latest ${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4bb5e63b..53c1cf32 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -23,7 +23,11 @@ jobs: cache: true - name: Generate Build Metadata run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV + VERSION=$(git describe --tags --always --dirty) + if [[ "$VERSION" != *"-plus" ]]; then + VERSION="${VERSION}-plus" + fi + echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - uses: goreleaser/goreleaser-action@v4 diff --git a/.goreleaser.yml b/.goreleaser.yml index 31d05e6d..8fbec015 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -12,7 +12,7 @@ builds: main: ./cmd/server/ binary: cli-proxy-api ldflags: - - -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' + - -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' archives: - id: "cli-proxy-api" format: tar.gz From 52e5551d8f2298274338c467c6629ad4978023de Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 30 Nov 2025 17:54:37 +0800 Subject: [PATCH 06/16] **chore(build): rename artifacts and adjust workflows for 'plus' variant** - Updated Docker and release workflows to use `cli-proxy-api-plus` as the Docker repository name and adjusted tag generation logic. - Renamed artifacts in Goreleaser configuration to align with the new 'plus' variant naming convention. --- .github/workflows/docker-image.yml | 9 +++------ .github/workflows/release.yaml | 3 --- .goreleaser.yml | 6 +++--- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 63e1c4b7..de924672 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -7,7 +7,7 @@ on: env: APP_NAME: CLIProxyAPI - DOCKERHUB_REPO: eceasy/cli-proxy-api + DOCKERHUB_REPO: eceasy/cli-proxy-api-plus jobs: docker: @@ -26,11 +26,7 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Generate Build Metadata run: | - VERSION=$(git describe --tags --always --dirty) - if [[ "$VERSION" != *"-plus" ]]; then - VERSION="${VERSION}-plus" - fi - echo "VERSION=${VERSION}" >> $GITHUB_ENV + echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - name: Build and push @@ -46,4 +42,5 @@ jobs: COMMIT=${{ env.COMMIT }} BUILD_DATE=${{ env.BUILD_DATE }} tags: | + ${{ env.DOCKERHUB_REPO }}:latest ${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 53c1cf32..4c4aafe7 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -24,9 +24,6 @@ jobs: - name: Generate Build Metadata run: | VERSION=$(git describe --tags --always --dirty) - if [[ "$VERSION" != *"-plus" ]]; then - VERSION="${VERSION}-plus" - fi echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV diff --git a/.goreleaser.yml b/.goreleaser.yml index 8fbec015..6e1829ed 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -1,5 +1,5 @@ builds: - - id: "cli-proxy-api" + - id: "cli-proxy-api-plus" env: - CGO_ENABLED=0 goos: @@ -10,11 +10,11 @@ builds: - amd64 - arm64 main: ./cmd/server/ - binary: cli-proxy-api + binary: cli-proxy-api-plus ldflags: - -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' archives: - - id: "cli-proxy-api" + - id: "cli-proxy-api-plus" format: tar.gz format_overrides: - goos: windows From 8203bf64ec554e1fc93ceeca870e51d45dde08b1 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 30 Nov 2025 20:16:51 +0800 Subject: [PATCH 07/16] **docs: update README for CLIProxyAPI Plus and rename artifacts** - Updated README files to reflect the new 'Plus' variant with detailed third-party provider support information. - Adjusted Dockerfile, `docker-compose.yml`, and build configurations to align with the `CLIProxyAPIPlus` naming convention. - Added information about GitHub Copilot OAuth integration contributed by the community. --- Dockerfile | 6 +-- README.md | 93 ++++------------------------------------- README_CN.md | 100 ++++----------------------------------------- docker-compose.yml | 4 +- 4 files changed, 23 insertions(+), 180 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8623dc5e..98509423 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ ARG VERSION=dev ARG COMMIT=none ARG BUILD_DATE=unknown -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/ +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}-plus' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPIPlus ./cmd/server/ FROM alpine:3.22.0 @@ -20,7 +20,7 @@ RUN apk add --no-cache tzdata RUN mkdir /CLIProxyAPI -COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI +COPY --from=builder ./app/CLIProxyAPIPlus /CLIProxyAPI/CLIProxyAPIPlus COPY config.example.yaml /CLIProxyAPI/config.example.yaml @@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone -CMD ["./CLIProxyAPI"] \ No newline at end of file +CMD ["./CLIProxyAPIPlus"] \ No newline at end of file diff --git a/README.md b/README.md index 90d5d465..8e27ab05 100644 --- a/README.md +++ b/README.md @@ -1,97 +1,22 @@ -# CLI Proxy API +# CLIProxyAPI Plus -English | [中文](README_CN.md) +English | [Chinese](README_CN.md) -A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI. +This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project. -It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth. +All third-party provider support is maintained by community contributors; CLIProxyAPI does not provide technical support. Please contact the corresponding community maintainer if you need assistance. -So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs. +The Plus release stays in lockstep with the mainline features. -## Sponsor +## Differences from the Mainline -[![z.ai](https://assets.router-for.me/english.png)](https://z.ai/subscribe?ic=8JVLJQFSKB) - -This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN. - -GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences. - -Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB - -## Overview - -- OpenAI/Gemini/Claude compatible API endpoints for CLI models -- OpenAI Codex support (GPT models) via OAuth login -- Claude Code support via OAuth login -- Qwen Code support via OAuth login -- iFlow support via OAuth login -- Amp CLI and IDE extensions support with provider routing -- Streaming and non-streaming responses -- Function calling/tools support -- Multimodal input support (text and images) -- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow) -- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow) -- Generative Language API Key support -- AI Studio Build multi-account load balancing -- Gemini CLI multi-account load balancing -- Claude Code multi-account load balancing -- Qwen Code multi-account load balancing -- iFlow multi-account load balancing -- OpenAI Codex multi-account load balancing -- OpenAI-compatible upstream providers via config (e.g., OpenRouter) -- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`) - -## Getting Started - -CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/) - -## Management API - -see [MANAGEMENT_API.md](https://help.router-for.me/management/api) - -## Amp CLI Support - -CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools: - -- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`) -- Management proxy for OAuth authentication and account features -- Smart model fallback with automatic routing -- Security-first design with localhost-only management endpoints - -**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)** - -## SDK Docs - -- Usage: [docs/sdk-usage.md](docs/sdk-usage.md) -- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md) -- Access: [docs/sdk-access.md](docs/sdk-access.md) -- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md) -- Custom Provider Example: `examples/custom-provider` +- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) ## Contributing -Contributions are welcome! Please feel free to submit a Pull Request. +This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected. -1. Fork the repository -2. Create your feature branch (`git checkout -b feature/amazing-feature`) -3. Commit your changes (`git commit -m 'Add some amazing feature'`) -4. Push to the branch (`git push origin feature/amazing-feature`) -5. Open a Pull Request - -## Who is with us? - -Those projects are based on CLIProxyAPI: - -### [vibeproxy](https://github.com/automazeio/vibeproxy) - -Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed - -### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator) - -Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed - -> [!NOTE] -> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. +If you need to submit any non-third-party provider changes, please open them against the mainline repository. ## License diff --git a/README_CN.md b/README_CN.md index be0aa234..84e90d40 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,105 +1,23 @@ -# CLI 代理 API +# CLIProxyAPI Plus [English](README.md) | 中文 -一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。 +这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。 -现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。 +所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。 -您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。 +该 Plus 版本的主线功能与主线功能强制同步。 -## 赞助商 +## 与主线版本版本差异 -[![bigmodel.cn](https://assets.router-for.me/chinese.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII) - -本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。 - -GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。 - -智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII - -## 功能特性 - -- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点 -- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录) -- 新增 Claude Code 支持(OAuth 登录) -- 新增 Qwen Code 支持(OAuth 登录) -- 新增 iFlow 支持(OAuth 登录) -- 支持流式与非流式响应 -- 函数调用/工具支持 -- 多模态输入(文本、图片) -- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow) -- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow) -- 支持 Gemini AIStudio API 密钥 -- 支持 AI Studio Build 多账户轮询 -- 支持 Gemini CLI 多账户轮询 -- 支持 Claude Code 多账户轮询 -- 支持 Qwen Code 多账户轮询 -- 支持 iFlow 多账户轮询 -- 支持 OpenAI Codex 多账户轮询 -- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter) -- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`) - -## 新手入门 - -CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/) - -## 管理 API 文档 - -请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api) - -## Amp CLI 支持 - -CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具: - -- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`) -- 管理代理,处理 OAuth 认证和账号功能 -- 智能模型回退与自动路由 -- 以安全为先的设计,管理端点仅限 localhost - -**→ [Amp CLI 完整集成指南](docs/amp-cli-integration_CN.md)** - -## SDK 文档 - -- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md) -- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md) -- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md) -- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md) -- 自定义 Provider 示例:`examples/custom-provider` +- 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 ## 贡献 -欢迎贡献!请随时提交 Pull Request。 +该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。 -1. Fork 仓库 -2. 创建您的功能分支(`git checkout -b feature/amazing-feature`) -3. 提交您的更改(`git commit -m 'Add some amazing feature'`) -4. 推送到分支(`git push origin feature/amazing-feature`) -5. 打开 Pull Request - -## 谁与我们在一起? - -这些项目基于 CLIProxyAPI: - -### [vibeproxy](https://github.com/automazeio/vibeproxy) - -一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。 - -### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator) - -一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。 - -> [!NOTE] -> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。 +如果需要提交任何非第三方供应商支持的 Pull Request,请提交到主线版本。 ## 许可证 -此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 - -## 写给所有中国网友的 - -QQ 群:188637136 - -或 - -Telegram 群:https://t.me/CLIProxyAPI +此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 29712419..80693fbe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ services: cli-proxy-api: - image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest} + image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api-plus:latest} pull_policy: always build: context: . @@ -9,7 +9,7 @@ services: VERSION: ${VERSION:-dev} COMMIT: ${COMMIT:-none} BUILD_DATE: ${BUILD_DATE:-unknown} - container_name: cli-proxy-api + container_name: cli-proxy-api-plus # env_file: # - .env environment: From 691cdb6bdfae92c9dd1d4207bcbb36360ffbc55d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 5 Dec 2025 10:32:28 +0800 Subject: [PATCH 08/16] **fix(api): update GitHub release URL and user agent for CLIProxyAPIPlus** --- internal/api/handlers/management/config_basic.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index ae292982..f9069198 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -19,8 +19,8 @@ import ( ) const ( - latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest" - latestReleaseUserAgent = "CLIProxyAPI" + latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest" + latestReleaseUserAgent = "CLIProxyAPIPlus" ) func (h *Handler) GetConfig(c *gin.Context) { From 02d8a1cfecb2bf4ee1c5105507f4022f09e0bcde Mon Sep 17 00:00:00 2001 From: Mansi Date: Fri, 5 Dec 2025 22:46:24 +0300 Subject: [PATCH 09/16] feat(kiro): add AWS Builder ID authentication support - Add --kiro-aws-login flag for AWS Builder ID device code flow - Add DoKiroAWSLogin function for AWS SSO OIDC authentication - Complete Kiro integration with AWS, Google OAuth, and social auth - Add kiro executor, translator, and SDK components - Update browser support for Kiro authentication flows --- .gitignore | 2 + cmd/server/main.go | 50 + config.example.yaml | 15 + go.mod | 3 +- go.sum | 5 +- .../api/handlers/management/auth_files.go | 155 +- internal/api/server.go | 2 +- internal/auth/claude/oauth_server.go | 11 + internal/auth/codex/oauth_server.go | 11 + internal/auth/iflow/iflow_auth.go | 22 +- internal/auth/kiro/aws.go | 301 +++ internal/auth/kiro/aws_auth.go | 314 +++ internal/auth/kiro/aws_test.go | 161 ++ internal/auth/kiro/oauth.go | 296 +++ internal/auth/kiro/protocol_handler.go | 725 +++++ internal/auth/kiro/social_auth.go | 403 +++ internal/auth/kiro/sso_oidc.go | 527 ++++ internal/auth/kiro/token.go | 72 + internal/browser/browser.go | 444 +++- internal/cmd/auth_manager.go | 1 + internal/cmd/kiro_login.go | 160 ++ internal/config/config.go | 53 + internal/constant/constant.go | 3 + internal/logging/global_logger.go | 14 +- internal/registry/model_definitions.go | 158 ++ .../runtime/executor/antigravity_executor.go | 10 +- internal/runtime/executor/cache_helpers.go | 32 +- internal/runtime/executor/codex_executor.go | 4 +- internal/runtime/executor/kiro_executor.go | 2353 +++++++++++++++++ internal/runtime/executor/proxy_helpers.go | 49 +- .../claude/gemini/claude_gemini_response.go | 5 +- .../claude_openai_response.go | 4 + .../claude/gemini-cli_claude_response.go | 97 +- internal/translator/init.go | 3 + internal/translator/kiro/claude/init.go | 19 + .../translator/kiro/claude/kiro_claude.go | 24 + .../kiro/openai/chat-completions/init.go | 19 + .../chat-completions/kiro_openai_request.go | 258 ++ .../chat-completions/kiro_openai_response.go | 316 +++ internal/watcher/watcher.go | 181 +- sdk/api/handlers/handlers.go | 32 +- sdk/auth/kiro.go | 357 +++ sdk/auth/manager.go | 13 + sdk/auth/refresh_registry.go | 1 + sdk/cliproxy/service.go | 5 + 45 files changed, 7519 insertions(+), 171 deletions(-) create mode 100644 internal/auth/kiro/aws.go create mode 100644 internal/auth/kiro/aws_auth.go create mode 100644 internal/auth/kiro/aws_test.go create mode 100644 internal/auth/kiro/oauth.go create mode 100644 internal/auth/kiro/protocol_handler.go create mode 100644 internal/auth/kiro/social_auth.go create mode 100644 internal/auth/kiro/sso_oidc.go create mode 100644 internal/auth/kiro/token.go create mode 100644 internal/cmd/kiro_login.go create mode 100644 internal/runtime/executor/kiro_executor.go create mode 100644 internal/translator/kiro/claude/init.go create mode 100644 internal/translator/kiro/claude/kiro_claude.go create mode 100644 internal/translator/kiro/openai/chat-completions/init.go create mode 100644 internal/translator/kiro/openai/chat-completions/kiro_openai_request.go create mode 100644 internal/translator/kiro/openai/chat-completions/kiro_openai_response.go create mode 100644 sdk/auth/kiro.go diff --git a/.gitignore b/.gitignore index 9e730c98..9014b273 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Binaries cli-proxy-api +cliproxy *.exe # Configuration @@ -31,6 +32,7 @@ GEMINI.md .vscode/* .claude/* .serena/* +.mcp/cache/ # macOS .DS_Store diff --git a/cmd/server/main.go b/cmd/server/main.go index bbf500e7..28c3e514 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -62,10 +62,16 @@ func main() { var iflowCookie bool var noBrowser bool var antigravityLogin bool + var kiroLogin bool + var kiroGoogleLogin bool + var kiroAWSLogin bool + var kiroImport bool var projectID string var vertexImport string var configPath string var password string + var noIncognito bool + var useIncognito bool // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") @@ -75,7 +81,13 @@ func main() { flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") + flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)") + flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") + flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth") + flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)") + flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)") + flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -448,6 +460,44 @@ func main() { cmd.DoIFlowLogin(cfg, options) } else if iflowCookie { cmd.DoIFlowCookieAuth(cfg, options) + } else if kiroLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + // Note: This config mutation is safe - auth commands exit after completion + // and don't share config with StartService (which is in the else branch) + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } + cmd.DoKiroLogin(cfg, options) + } else if kiroGoogleLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + // Note: This config mutation is safe - auth commands exit after completion + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } + cmd.DoKiroGoogleLogin(cfg, options) + } else if kiroAWSLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } + cmd.DoKiroAWSLogin(cfg, options) + } else if kiroImport { + cmd.DoKiroImport(cfg, options) } else { // In cloud deploy mode without config file, just wait for shutdown signals if isCloudDeploy && !configFileExists { diff --git a/config.example.yaml b/config.example.yaml index 61f51d47..8fb01c14 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -32,6 +32,11 @@ api-keys: # Enable debug logging debug: false +# Open OAuth URLs in incognito/private browser mode. +# Useful when you want to login with a different account without logging out from your current session. +# Default: false (but Kiro auth defaults to true for multi-account support) +incognito-browser: true + # When true, write application logs to rotating files instead of stdout logging-to-file: false @@ -99,6 +104,16 @@ ws-auth: false # - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking) # - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022) +# Kiro (AWS CodeWhisperer) configuration +# Note: Kiro API currently only operates in us-east-1 region +#kiro: +# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file +# agent-task-type: "" # optional: "vibe" or empty (API default) +# - access-token: "aoaAAAAA..." # or provide tokens directly +# refresh-token: "aorAAAAA..." +# profile-arn: "arn:aws:codewhisperer:us-east-1:..." +# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override + # OpenAI compatibility providers # openai-compatibility: # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places. diff --git a/go.mod b/go.mod index c7660c96..85d816c9 100644 --- a/go.mod +++ b/go.mod @@ -13,14 +13,15 @@ require ( github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.4 github.com/minio/minio-go/v7 v7.0.66 + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/sirupsen/logrus v1.9.3 - github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/tiktoken-go/tokenizer v0.7.0 golang.org/x/crypto v0.43.0 golang.org/x/net v0.46.0 golang.org/x/oauth2 v0.30.0 + golang.org/x/term v0.36.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 3acf8562..833c45b9 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -126,8 +128,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -169,6 +169,7 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 6f77fda9..161f7a93 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -36,9 +36,32 @@ import ( ) var ( - oauthStatus = make(map[string]string) + oauthStatus = make(map[string]string) + oauthStatusMutex sync.RWMutex ) +// getOAuthStatus safely retrieves an OAuth status +func getOAuthStatus(key string) (string, bool) { + oauthStatusMutex.RLock() + defer oauthStatusMutex.RUnlock() + status, ok := oauthStatus[key] + return status, ok +} + +// setOAuthStatus safely sets an OAuth status +func setOAuthStatus(key string, status string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + oauthStatus[key] = status +} + +// deleteOAuthStatus safely deletes an OAuth status +func deleteOAuthStatus(key string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + delete(oauthStatus, key) +} + var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -760,7 +783,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { - oauthStatus[state] = "Timeout waiting for OAuth callback" + setOAuthStatus(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") } data, errRead := os.ReadFile(path) @@ -785,13 +808,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad request" + setOAuthStatus(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "State code error" + setOAuthStatus(state, "State code error") return } @@ -824,7 +847,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + setOAuthStatus(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -835,7 +858,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -848,7 +871,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -873,7 +896,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -882,10 +905,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -944,7 +967,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + setOAuthStatus(state, "OAuth flow timed out") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -953,13 +976,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") return } authCode = m["code"] if authCode == "" { log.Errorf("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + setOAuthStatus(state, "Authentication failed: code not found") return } break @@ -971,7 +994,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - oauthStatus[state] = "Failed to exchange token" + setOAuthStatus(state, "Failed to exchange token") return } @@ -982,7 +1005,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - oauthStatus[state] = "Could not get user info" + setOAuthStatus(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -991,7 +1014,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - oauthStatus[state] = "Failed to execute request" + setOAuthStatus(state, "Failed to execute request") return } defer func() { @@ -1003,7 +1026,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1012,7 +1035,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") - oauthStatus[state] = "Failed to get user email from token" + setOAuthStatus(state, "Failed to get user email from token") } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1020,7 +1043,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - oauthStatus[state] = "Failed to unmarshal token" + setOAuthStatus(state, "Failed to unmarshal token") return } @@ -1046,7 +1069,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { log.Fatalf("failed to get authenticated client: %v", errGetClient) - oauthStatus[state] = "Failed to get authenticated client" + setOAuthStatus(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1056,12 +1079,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - oauthStatus[state] = "Failed to verify Cloud AI API status" + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1069,26 +1092,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - oauthStatus[state] = "Failed to resolve project ID" + setOAuthStatus(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - oauthStatus[state] = "Failed to verify Cloud AI API status" + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - oauthStatus[state] = "Cloud AI API not enabled" + setOAuthStatus(state, "Cloud AI API not enabled") return } } @@ -1111,15 +1134,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + setOAuthStatus(state, "Failed to save token to file") return } - delete(oauthStatus, state) + deleteOAuthStatus(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1180,7 +1203,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "Timeout waiting for OAuth callback" + setOAuthStatus(state, "Timeout waiting for OAuth callback") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1190,12 +1213,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if errStr := m["error"]; errStr != "" { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(codex.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad Request" + setOAuthStatus(state, "Bad Request") return } if m["state"] != state { authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - oauthStatus[state] = "State code error" + setOAuthStatus(state, "State code error") log.Error(codex.GetUserFriendlyMessage(authErr)) return } @@ -1226,14 +1249,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + setOAuthStatus(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1244,7 +1267,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1282,7 +1305,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") log.Fatalf("Failed to save authentication tokens: %v", errSave) return } @@ -1291,10 +1314,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1360,7 +1383,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + setOAuthStatus(state, "OAuth flow timed out") return } if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { @@ -1369,18 +1392,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := strings.TrimSpace(payload["error"]); errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") return } if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { log.Errorf("Authentication failed: state mismatch") - oauthStatus[state] = "Authentication failed: state mismatch" + setOAuthStatus(state, "Authentication failed: state mismatch") return } authCode = strings.TrimSpace(payload["code"]) if authCode == "" { log.Error("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + setOAuthStatus(state, "Authentication failed: code not found") return } break @@ -1399,7 +1422,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - oauthStatus[state] = "Failed to build token request" + setOAuthStatus(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1407,7 +1430,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - oauthStatus[state] = "Failed to exchange token" + setOAuthStatus(state, "Failed to exchange token") return } defer func() { @@ -1419,7 +1442,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1431,7 +1454,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") return } @@ -1440,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - oauthStatus[state] = "Failed to build user info request" + setOAuthStatus(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1448,7 +1471,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info request: %v", errInfo) - oauthStatus[state] = "Failed to execute user info request" + setOAuthStatus(state, "Failed to execute user info request") return } defer func() { @@ -1467,7 +1490,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1515,11 +1538,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + setOAuthStatus(state, "Failed to save token to file") return } - delete(oauthStatus, state) + deleteOAuthStatus(state) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1527,7 +1550,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1552,7 +1575,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { fmt.Println("Waiting for authentication...") tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) if errPollForToken != nil { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1571,16 +1594,16 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1619,7 +1642,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { var resultMap map[string]string for { if time.Now().After(deadline) { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return } @@ -1632,26 +1655,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1673,7 +1696,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") log.Fatalf("Failed to save authentication tokens: %v", errSave) return } @@ -1683,10 +1706,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2110,7 +2133,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") - if err, ok := oauthStatus[state]; ok { + if err, ok := getOAuthStatus(state); ok { if err != "" { c.JSON(200, gin.H{"status": "error", "error": err}) } else { @@ -2120,5 +2143,5 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } else { c.JSON(200, gin.H{"status": "ok"}) } - delete(oauthStatus, state) + deleteOAuthStatus(state) } diff --git a/internal/api/server.go b/internal/api/server.go index 9e1c5848..72cb0313 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -902,7 +902,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { for _, p := range cfg.OpenAICompatibility { providerNames = append(providerNames, p.Name) } - s.handlers.OpenAICompatProviders = providerNames + s.handlers.SetOpenAICompatProviders(providerNames) s.handlers.UpdateClients(&cfg.SDKConfig) diff --git a/internal/auth/claude/oauth_server.go b/internal/auth/claude/oauth_server.go index a6ebe2f7..49b04794 100644 --- a/internal/auth/claude/oauth_server.go +++ b/internal/auth/claude/oauth_server.go @@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://console.anthropic.com/" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://console.anthropic.com/" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. diff --git a/internal/auth/codex/oauth_server.go b/internal/auth/codex/oauth_server.go index 9c6a6c5b..58b5394e 100644 --- a/internal/auth/codex/oauth_server.go +++ b/internal/auth/codex/oauth_server.go @@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://platform.openai.com" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://platform.openai.com" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index 4957f519..b3431f84 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "os" "strings" "time" @@ -28,10 +29,21 @@ const ( iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey" // Client credentials provided by iFlow for the Code Assist integration. - iFlowOAuthClientID = "10009311001" - iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" + iFlowOAuthClientID = "10009311001" + // Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var) + defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" ) +// getIFlowClientSecret returns the iFlow OAuth client secret. +// It first checks the IFLOW_CLIENT_SECRET environment variable, +// falling back to the default value if not set. +func getIFlowClientSecret() string { + if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" { + return secret + } + return defaultIFlowClientSecret +} + // DefaultAPIBaseURL is the canonical chat completions endpoint. const DefaultAPIBaseURL = "https://apis.iflow.cn/v1" @@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR form.Set("code", code) form.Set("redirect_uri", redirectURI) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) + form.Set("client_secret", getIFlowClientSecret()) req, err := ia.newTokenRequest(ctx, form) if err != nil { @@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I form.Set("grant_type", "refresh_token") form.Set("refresh_token", refreshToken) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) + form.Set("client_secret", getIFlowClientSecret()) req, err := ia.newTokenRequest(ctx, form) if err != nil { @@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt return nil, fmt.Errorf("iflow token: create request failed: %w", err) } - basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret)) + basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Basic "+basic) diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go new file mode 100644 index 00000000..9be025c2 --- /dev/null +++ b/internal/auth/kiro/aws.go @@ -0,0 +1,301 @@ +// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. +// It includes interfaces and implementations for token storage and authentication methods. +package kiro + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" +) + +// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow +type PKCECodes struct { + // CodeVerifier is the cryptographically random string used to correlate + // the authorization request to the token request + CodeVerifier string `json:"code_verifier"` + // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded + CodeChallenge string `json:"code_challenge"` +} + +// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro) +type KiroTokenData struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"accessToken"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refreshToken"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profileArn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expiresAt"` + // AuthMethod indicates the authentication method used (e.g., "builder-id", "social") + AuthMethod string `json:"authMethod"` + // Provider indicates the OAuth provider (e.g., "AWS", "Google") + Provider string `json:"provider"` + // ClientID is the OIDC client ID (needed for token refresh) + ClientID string `json:"clientId,omitempty"` + // ClientSecret is the OIDC client secret (needed for token refresh) + ClientSecret string `json:"clientSecret,omitempty"` + // Email is the user's email address (used for file naming) + Email string `json:"email,omitempty"` +} + +// KiroAuthBundle aggregates authentication data after OAuth flow completion +type KiroAuthBundle struct { + // TokenData contains the OAuth tokens from the authentication flow + TokenData KiroTokenData `json:"token_data"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// KiroUsageInfo represents usage information from CodeWhisperer API +type KiroUsageInfo struct { + // SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE") + SubscriptionTitle string `json:"subscription_title"` + // CurrentUsage is the current credit usage + CurrentUsage float64 `json:"current_usage"` + // UsageLimit is the maximum credit limit + UsageLimit float64 `json:"usage_limit"` + // NextReset is the timestamp of the next usage reset + NextReset string `json:"next_reset"` +} + +// KiroModel represents a model available through the CodeWhisperer API +type KiroModel struct { + // ModelID is the unique identifier for the model + ModelID string `json:"modelId"` + // ModelName is the human-readable name + ModelName string `json:"modelName"` + // Description is the model description + Description string `json:"description"` + // RateMultiplier is the credit multiplier for this model + RateMultiplier float64 `json:"rateMultiplier"` + // RateUnit is the unit for rate calculation (e.g., "credit") + RateUnit string `json:"rateUnit"` + // MaxInputTokens is the maximum input token limit + MaxInputTokens int `json:"maxInputTokens,omitempty"` +} + +// KiroIDETokenFile is the default path to Kiro IDE's token file +const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json" + +// LoadKiroIDEToken loads token data from Kiro IDE's token file. +func LoadKiroIDEToken() (*KiroTokenData, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, KiroIDETokenFile) + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in Kiro IDE token file") + } + + return &token, nil +} + +// LoadKiroTokenFromPath loads token data from a custom path. +// This supports multiple accounts by allowing different token files. +func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { + // Expand ~ to home directory + if len(tokenPath) > 0 && tokenPath[0] == '~' { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenPath = filepath.Join(homeDir, tokenPath[1:]) + } + + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in token file") + } + + return &token, nil +} + +// ListKiroTokenFiles lists all Kiro token files in the cache directory. +// This supports multiple accounts by finding all token files. +func ListKiroTokenFiles() ([]string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + return nil, nil // No token files + } + + entries, err := os.ReadDir(cacheDir) + if err != nil { + return nil, fmt.Errorf("failed to read cache directory: %w", err) + } + + var tokenFiles []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + // Look for kiro token files only (avoid matching unrelated AWS SSO cache files) + if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") { + tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name)) + } + } + + return tokenFiles, nil +} + +// LoadAllKiroTokens loads all Kiro tokens from the cache directory. +// This supports multiple accounts. +func LoadAllKiroTokens() ([]*KiroTokenData, error) { + files, err := ListKiroTokenFiles() + if err != nil { + return nil, err + } + + var tokens []*KiroTokenData + for _, file := range files { + token, err := LoadKiroTokenFromPath(file) + if err != nil { + // Skip invalid token files + continue + } + tokens = append(tokens, token) + } + + return tokens, nil +} + +// JWTClaims represents the claims we care about from a JWT token. +// JWT tokens from Kiro/AWS contain user information in the payload. +type JWTClaims struct { + Email string `json:"email,omitempty"` + Sub string `json:"sub,omitempty"` + PreferredUser string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Iss string `json:"iss,omitempty"` +} + +// ExtractEmailFromJWT extracts the user's email from a JWT access token. +// JWT tokens typically have format: header.payload.signature +// The payload is base64url-encoded JSON containing user claims. +func ExtractEmailFromJWT(accessToken string) string { + if accessToken == "" { + return "" + } + + // JWT format: header.payload.signature + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return "" + } + + // Decode the payload (second part) + payload := parts[1] + + // Add padding if needed (base64url requires padding) + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + // Try RawURLEncoding (no padding) + decoded, err = base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "" + } + } + + var claims JWTClaims + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + // Return email if available + if claims.Email != "" { + return claims.Email + } + + // Fallback to preferred_username (some providers use this) + if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") { + return claims.PreferredUser + } + + // Fallback to sub if it looks like an email + if claims.Sub != "" && strings.Contains(claims.Sub, "@") { + return claims.Sub + } + + return "" +} + +// SanitizeEmailForFilename sanitizes an email address for use in a filename. +// Replaces special characters with underscores and prevents path traversal attacks. +// Also handles URL-encoded characters to prevent encoded path traversal attempts. +func SanitizeEmailForFilename(email string) string { + if email == "" { + return "" + } + + result := email + + // First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.) + // This prevents encoded characters from bypassing the sanitization. + // Note: We replace % last to catch any remaining encodings including double-encoding (%252F) + result = strings.ReplaceAll(result, "%2F", "_") // / + result = strings.ReplaceAll(result, "%2f", "_") + result = strings.ReplaceAll(result, "%5C", "_") // \ + result = strings.ReplaceAll(result, "%5c", "_") + result = strings.ReplaceAll(result, "%2E", "_") // . + result = strings.ReplaceAll(result, "%2e", "_") + result = strings.ReplaceAll(result, "%00", "_") // null byte + result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks + + // Replace characters that are problematic in filenames + // Keep @ and . in middle but replace other special characters + for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} { + result = strings.ReplaceAll(result, char, "_") + } + + // Prevent path traversal: replace leading dots in each path component + // This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd" + parts := strings.Split(result, "_") + for i, part := range parts { + for strings.HasPrefix(part, ".") { + part = "_" + part[1:] + } + parts[i] = part + } + result = strings.Join(parts, "_") + + return result +} diff --git a/internal/auth/kiro/aws_auth.go b/internal/auth/kiro/aws_auth.go new file mode 100644 index 00000000..53c77a8b --- /dev/null +++ b/internal/auth/kiro/aws_auth.go @@ -0,0 +1,314 @@ +// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API. +// This package implements token loading, refresh, and API communication with CodeWhisperer. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.) + // Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com) + // used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct + // for their respective API operations. + awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com" + defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json" + targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits" + targetListModels = "AmazonCodeWhispererService.ListAvailableModels" + targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" +) + +// KiroAuth handles AWS CodeWhisperer authentication and API communication. +// It provides methods for loading tokens, refreshing expired tokens, +// and communicating with the CodeWhisperer API. +type KiroAuth struct { + httpClient *http.Client + endpoint string +} + +// NewKiroAuth creates a new Kiro authentication service. +// It initializes the HTTP client with proxy settings from the configuration. +// +// Parameters: +// - cfg: The application configuration containing proxy settings +// +// Returns: +// - *KiroAuth: A new Kiro authentication service instance +func NewKiroAuth(cfg *config.Config) *KiroAuth { + return &KiroAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}), + endpoint: awsKiroEndpoint, + } +} + +// LoadTokenFromFile loads token data from a file path. +// This method reads and parses the token file, expanding ~ to the home directory. +// +// Parameters: +// - tokenFile: Path to the token file (supports ~ expansion) +// +// Returns: +// - *KiroTokenData: The parsed token data +// - error: An error if file reading or parsing fails +func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) { + // Expand ~ to home directory + if strings.HasPrefix(tokenFile, "~") { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenFile = filepath.Join(home, tokenFile[1:]) + } + + data, err := os.ReadFile(tokenFile) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var tokenData KiroTokenData + if err := json.Unmarshal(data, &tokenData); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &tokenData, nil +} + +// IsTokenExpired checks if the token has expired. +// This method parses the expiration timestamp and compares it with the current time. +// +// Parameters: +// - tokenData: The token data to check +// +// Returns: +// - bool: True if the token has expired, false otherwise +func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool { + if tokenData.ExpiresAt == "" { + return true + } + + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + // Try alternate format + expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt) + if err != nil { + return true + } + } + + return time.Now().After(expiresAt) +} + +// makeRequest sends a request to the CodeWhisperer API. +// This is an internal method for making authenticated API calls. +// +// Parameters: +// - ctx: The context for the request +// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits") +// - accessToken: The OAuth access token +// - payload: The request payload +// +// Returns: +// - []byte: The response body +// - error: An error if the request fails +func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) { + jsonBody, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", target) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := k.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("failed to close response body: %v", errClose) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + return body, nil +} + +// GetUsageLimits retrieves usage information from the CodeWhisperer API. +// This method fetches the current usage statistics and subscription information. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data containing access token and profile ARN +// +// Returns: +// - *KiroUsageInfo: The usage information +// - error: An error if the request fails +func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + "resourceType": "AGENTIC_REQUEST", + } + + body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload) + if err != nil { + return nil, err + } + + var result struct { + SubscriptionInfo struct { + SubscriptionTitle string `json:"subscriptionTitle"` + } `json:"subscriptionInfo"` + UsageBreakdownList []struct { + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + } `json:"usageBreakdownList"` + NextDateReset float64 `json:"nextDateReset"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse usage response: %w", err) + } + + usage := &KiroUsageInfo{ + SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle, + NextReset: fmt.Sprintf("%v", result.NextDateReset), + } + + if len(result.UsageBreakdownList) > 0 { + usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision + usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision + } + + return usage, nil +} + +// ListAvailableModels retrieves available models from the CodeWhisperer API. +// This method fetches the list of AI models available for the authenticated user. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data containing access token and profile ARN +// +// Returns: +// - []*KiroModel: The list of available models +// - error: An error if the request fails +func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + } + + body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload) + if err != nil { + return nil, err + } + + var result struct { + Models []struct { + ModelID string `json:"modelId"` + ModelName string `json:"modelName"` + Description string `json:"description"` + RateMultiplier float64 `json:"rateMultiplier"` + RateUnit string `json:"rateUnit"` + TokenLimits struct { + MaxInputTokens int `json:"maxInputTokens"` + } `json:"tokenLimits"` + } `json:"models"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + + models := make([]*KiroModel, 0, len(result.Models)) + for _, m := range result.Models { + models = append(models, &KiroModel{ + ModelID: m.ModelID, + ModelName: m.ModelName, + Description: m.Description, + RateMultiplier: m.RateMultiplier, + RateUnit: m.RateUnit, + MaxInputTokens: m.TokenLimits.MaxInputTokens, + }) + } + + return models, nil +} + +// CreateTokenStorage creates a new KiroTokenStorage from token data. +// This method converts the token data into a storage structure suitable for persistence. +// +// Parameters: +// - tokenData: The token data to convert +// +// Returns: +// - *KiroTokenStorage: A new token storage instance +func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage { + return &KiroTokenStorage{ + AccessToken: tokenData.AccessToken, + RefreshToken: tokenData.RefreshToken, + ProfileArn: tokenData.ProfileArn, + ExpiresAt: tokenData.ExpiresAt, + AuthMethod: tokenData.AuthMethod, + Provider: tokenData.Provider, + LastRefresh: time.Now().Format(time.RFC3339), + } +} + +// ValidateToken checks if the token is valid by making a test API call. +// This method verifies the token by attempting to fetch usage limits. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data to validate +// +// Returns: +// - error: An error if the token is invalid +func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error { + _, err := k.GetUsageLimits(ctx, tokenData) + return err +} + +// UpdateTokenStorage updates an existing token storage with new token data. +// This method refreshes the token storage with newly obtained access and refresh tokens. +// +// Parameters: +// - storage: The existing token storage to update +// - tokenData: The new token data to apply +func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) { + storage.AccessToken = tokenData.AccessToken + storage.RefreshToken = tokenData.RefreshToken + storage.ProfileArn = tokenData.ProfileArn + storage.ExpiresAt = tokenData.ExpiresAt + storage.AuthMethod = tokenData.AuthMethod + storage.Provider = tokenData.Provider + storage.LastRefresh = time.Now().Format(time.RFC3339) +} diff --git a/internal/auth/kiro/aws_test.go b/internal/auth/kiro/aws_test.go new file mode 100644 index 00000000..5f60294c --- /dev/null +++ b/internal/auth/kiro/aws_test.go @@ -0,0 +1,161 @@ +package kiro + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestExtractEmailFromJWT(t *testing.T) { + tests := []struct { + name string + token string + expected string + }{ + { + name: "Empty token", + token: "", + expected: "", + }, + { + name: "Invalid token format", + token: "not.a.valid.jwt", + expected: "", + }, + { + name: "Invalid token - not base64", + token: "xxx.yyy.zzz", + expected: "", + }, + { + name: "Valid JWT with email", + token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}), + expected: "test@example.com", + }, + { + name: "JWT without email but with preferred_username", + token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}), + expected: "user@domain.com", + }, + { + name: "JWT with email-like sub", + token: createTestJWT(map[string]any{"sub": "another@test.com"}), + expected: "another@test.com", + }, + { + name: "JWT without any email fields", + token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractEmailFromJWT(tt.token) + if result != tt.expected { + t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestSanitizeEmailForFilename(t *testing.T) { + tests := []struct { + name string + email string + expected string + }{ + { + name: "Empty email", + email: "", + expected: "", + }, + { + name: "Simple email", + email: "user@example.com", + expected: "user@example.com", + }, + { + name: "Email with space", + email: "user name@example.com", + expected: "user_name@example.com", + }, + { + name: "Email with special chars", + email: "user:name@example.com", + expected: "user_name@example.com", + }, + { + name: "Email with multiple special chars", + email: "user/name:test@example.com", + expected: "user_name_test@example.com", + }, + { + name: "Path traversal attempt", + email: "../../../etc/passwd", + expected: "_.__.__._etc_passwd", + }, + { + name: "Path traversal with backslash", + email: `..\..\..\..\windows\system32`, + expected: "_.__.__.__._windows_system32", + }, + { + name: "Null byte injection attempt", + email: "user\x00@evil.com", + expected: "user_@evil.com", + }, + // URL-encoded path traversal tests + { + name: "URL-encoded slash", + email: "user%2Fpath@example.com", + expected: "user_path@example.com", + }, + { + name: "URL-encoded backslash", + email: "user%5Cpath@example.com", + expected: "user_path@example.com", + }, + { + name: "URL-encoded dot", + email: "%2E%2E%2Fetc%2Fpasswd", + expected: "___etc_passwd", + }, + { + name: "URL-encoded null", + email: "user%00@evil.com", + expected: "user_@evil.com", + }, + { + name: "Double URL-encoding attack", + email: "%252F%252E%252E", + expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe) + }, + { + name: "Mixed case URL-encoding", + email: "%2f%2F%5c%5C", + expected: "____", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeEmailForFilename(tt.email) + if result != tt.expected { + t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected) + } + }) + } +} + +// createTestJWT creates a test JWT token with the given claims +func createTestJWT(claims map[string]any) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + + payloadBytes, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(payloadBytes) + + signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature")) + + return header + "." + payload + "." + signature +} diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go new file mode 100644 index 00000000..e828da14 --- /dev/null +++ b/internal/auth/kiro/oauth.go @@ -0,0 +1,296 @@ +// Package kiro provides OAuth2 authentication for Kiro using native Google login. +package kiro + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // Kiro auth endpoint + kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" + + // Default callback port + defaultCallbackPort = 9876 + + // Auth timeout + authTimeout = 10 * time.Minute +) + +// KiroTokenResponse represents the response from Kiro token endpoint. +type KiroTokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ProfileArn string `json:"profileArn"` + ExpiresIn int `json:"expiresIn"` +} + +// KiroOAuth handles the OAuth flow for Kiro authentication. +type KiroOAuth struct { + httpClient *http.Client + cfg *config.Config +} + +// NewKiroOAuth creates a new Kiro OAuth handler. +func NewKiroOAuth(cfg *config.Config) *KiroOAuth { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &KiroOAuth{ + httpClient: client, + cfg: cfg, + } +} + +// generateCodeVerifier generates a random code verifier for PKCE. +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// generateCodeChallenge generates the code challenge from verifier. +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +// generateState generates a random state parameter. +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// AuthResult contains the authorization code and state from callback. +type AuthResult struct { + Code string + State string + Error string +} + +// startCallbackServer starts a local HTTP server to receive the OAuth callback. +func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) { + // Try to find an available port - use localhost like Kiro does + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort)) + if err != nil { + // Try with dynamic port (RFC 8252 allows dynamic ports for native apps) + log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort) + listener, err = net.Listen("tcp", "localhost:0") + if err != nil { + return "", nil, fmt.Errorf("failed to start callback server: %w", err) + } + } + + port := listener.Addr().(*net.TCPAddr).Port + // Use http scheme for local callback server + redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port) + resultChan := make(chan AuthResult, 1) + + server := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + if errParam != "" { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `

Login Failed

%s

You can close this window.

`, html.EscapeString(errParam)) + resultChan <- AuthResult{Error: errParam} + return + } + + if state != expectedState { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `

Login Failed

Invalid state parameter

You can close this window.

`) + resultChan <- AuthResult{Error: "state mismatch"} + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `

Login Successful!

You can close this window and return to the terminal.

`) + resultChan <- AuthResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(authTimeout): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow. +func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + ssoClient := NewSSOOIDCClient(o.cfg) + return ssoClient.LoginWithBuilderID(ctx) +} + +// exchangeCodeForToken exchanges the authorization code for tokens. +func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) { + payload := map[string]string{ + "code": code, + "code_verifier": codeVerifier, + "redirect_uri": redirectURI, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + tokenURL := kiroAuthEndpoint + "/oauth/token" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// RefreshToken refreshes an expired access token. +func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "refreshToken": refreshToken, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + refreshURL := kiroAuthEndpoint + "/refreshToken" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// LoginWithGoogle performs OAuth login with Google using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGoogle(ctx) +} + +// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGitHub(ctx) +} diff --git a/internal/auth/kiro/protocol_handler.go b/internal/auth/kiro/protocol_handler.go new file mode 100644 index 00000000..e07cc24d --- /dev/null +++ b/internal/auth/kiro/protocol_handler.go @@ -0,0 +1,725 @@ +// Package kiro provides custom protocol handler registration for Kiro OAuth. +// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub). +package kiro + +import ( + "context" + "fmt" + "html" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // KiroProtocol is the custom URI scheme used by Kiro + KiroProtocol = "kiro" + + // KiroAuthority is the URI authority for authentication callbacks + KiroAuthority = "kiro.kiroAgent" + + // KiroAuthPath is the path for successful authentication + KiroAuthPath = "/authenticate-success" + + // KiroRedirectURI is the full redirect URI for social auth + KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success" + + // DefaultHandlerPort is the default port for the local callback server + DefaultHandlerPort = 19876 + + // HandlerTimeout is how long to wait for the OAuth callback + HandlerTimeout = 10 * time.Minute +) + +// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks. +type ProtocolHandler struct { + port int + server *http.Server + listener net.Listener + resultChan chan *AuthCallback + stopChan chan struct{} + mu sync.Mutex + running bool +} + +// AuthCallback contains the OAuth callback parameters. +type AuthCallback struct { + Code string + State string + Error string +} + +// NewProtocolHandler creates a new protocol handler. +func NewProtocolHandler() *ProtocolHandler { + return &ProtocolHandler{ + port: DefaultHandlerPort, + resultChan: make(chan *AuthCallback, 1), + stopChan: make(chan struct{}), + } +} + +// Start starts the local callback server that receives redirects from the protocol handler. +func (h *ProtocolHandler) Start(ctx context.Context) (int, error) { + h.mu.Lock() + defer h.mu.Unlock() + + if h.running { + return h.port, nil + } + + // Drain any stale results from previous runs + select { + case <-h.resultChan: + default: + } + + // Reset stopChan for reuse - close old channel first to unblock any waiting goroutines + if h.stopChan != nil { + select { + case <-h.stopChan: + // Already closed + default: + close(h.stopChan) + } + } + h.stopChan = make(chan struct{}) + + // Try ports in known range (must match handler script port range) + var listener net.Listener + var err error + portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4} + + for _, port := range portRange { + listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err == nil { + break + } + log.Debugf("kiro protocol handler: port %d busy, trying next", port) + } + + if listener == nil { + return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4) + } + + h.listener = listener + h.port = listener.Addr().(*net.TCPAddr).Port + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", h.handleCallback) + + h.server = &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + + go func() { + if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("kiro protocol handler server error: %v", err) + } + }() + + h.running = true + log.Debugf("kiro protocol handler started on port %d", h.port) + + // Auto-shutdown after context done, timeout, or explicit stop + // Capture references to prevent race with new Start() calls + currentStopChan := h.stopChan + currentServer := h.server + currentListener := h.listener + go func() { + select { + case <-ctx.Done(): + case <-time.After(HandlerTimeout): + case <-currentStopChan: + return // Already stopped, exit goroutine + } + // Only stop if this is still the current server/listener instance + h.mu.Lock() + if h.server == currentServer && h.listener == currentListener { + h.mu.Unlock() + h.Stop() + } else { + h.mu.Unlock() + } + }() + + return h.port, nil +} + +// Stop stops the callback server. +func (h *ProtocolHandler) Stop() { + h.mu.Lock() + defer h.mu.Unlock() + + if !h.running { + return + } + + // Signal the auto-shutdown goroutine to exit. + // This select pattern is safe because stopChan is only modified while holding h.mu, + // and we hold the lock here. The select prevents panic from double-close. + select { + case <-h.stopChan: + // Already closed + default: + close(h.stopChan) + } + + if h.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = h.server.Shutdown(ctx) + } + + h.running = false + log.Debug("kiro protocol handler stopped") +} + +// WaitForCallback waits for the OAuth callback and returns the result. +func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(HandlerTimeout): + return nil, fmt.Errorf("timeout waiting for OAuth callback") + case result := <-h.resultChan: + return result, nil + } +} + +// GetPort returns the port the handler is listening on. +func (h *ProtocolHandler) GetPort() int { + return h.port +} + +// handleCallback processes the OAuth callback from the protocol handler script. +func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + result := &AuthCallback{ + Code: code, + State: state, + Error: errParam, + } + + // Send result + select { + case h.resultChan <- result: + default: + // Channel full, ignore duplicate callbacks + } + + // Send success response + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if errParam != "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, ` + +Login Failed + +

Login Failed

+

Error: %s

+

You can close this window.

+ +`, html.EscapeString(errParam)) + } else { + fmt.Fprint(w, ` + +Login Successful + +

Login Successful!

+

You can close this window and return to the terminal.

+ + +`) + } +} + +// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed. +func IsProtocolHandlerInstalled() bool { + switch runtime.GOOS { + case "linux": + return isLinuxHandlerInstalled() + case "windows": + return isWindowsHandlerInstalled() + case "darwin": + return isDarwinHandlerInstalled() + default: + return false + } +} + +// InstallProtocolHandler installs the kiro:// protocol handler for the current platform. +func InstallProtocolHandler(handlerPort int) error { + switch runtime.GOOS { + case "linux": + return installLinuxHandler(handlerPort) + case "windows": + return installWindowsHandler(handlerPort) + case "darwin": + return installDarwinHandler(handlerPort) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// UninstallProtocolHandler removes the kiro:// protocol handler. +func UninstallProtocolHandler() error { + switch runtime.GOOS { + case "linux": + return uninstallLinuxHandler() + case "windows": + return uninstallWindowsHandler() + case "darwin": + return uninstallDarwinHandler() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// --- Linux Implementation --- + +func getLinuxDesktopPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop") +} + +func getLinuxHandlerScriptPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler") +} + +func isLinuxHandlerInstalled() bool { + desktopPath := getLinuxDesktopPath() + _, err := os.Stat(desktopPath) + return err == nil +} + +func installLinuxHandler(handlerPort int) error { + // Create directories + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + binDir := filepath.Join(homeDir, ".local", "bin") + appDir := filepath.Join(homeDir, ".local", "share", "applications") + + if err := os.MkdirAll(binDir, 0755); err != nil { + return fmt.Errorf("failed to create bin directory: %w", err) + } + if err := os.MkdirAll(appDir, 0755); err != nil { + return fmt.Errorf("failed to create applications directory: %w", err) + } + + // Create handler script - tries multiple ports to handle dynamic port allocation + scriptPath := getLinuxHandlerScriptPath() + scriptContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler +# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE + +URL="$1" + +# Check curl availability +if ! command -v curl &> /dev/null; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try CLI proxy on multiple possible ports (default + dynamic range) +CLI_OK=0 +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break + fi +done + +# If CLI not available, forward to Kiro IDE +if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then + /usr/share/kiro/kiro --open-url "$URL" & +fi +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create .desktop file + desktopPath := getLinuxDesktopPath() + desktopContent := fmt.Sprintf(`[Desktop Entry] +Name=Kiro OAuth Handler +Comment=Handle kiro:// protocol for CLI Proxy API authentication +Exec=%s %%u +Type=Application +Terminal=false +NoDisplay=true +MimeType=x-scheme-handler/kiro; +Categories=Utility; +`, scriptPath) + + if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil { + return fmt.Errorf("failed to write desktop file: %w", err) + } + + // Register handler with xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("xdg-mime registration failed (may need manual setup): %v", err) + } + + // Update desktop database + cmd = exec.Command("update-desktop-database", appDir) + _ = cmd.Run() // Ignore errors, not critical + + log.Info("Kiro protocol handler installed for Linux") + return nil +} + +func uninstallLinuxHandler() error { + desktopPath := getLinuxDesktopPath() + scriptPath := getLinuxHandlerScriptPath() + + if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove desktop file: %w", err) + } + if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove handler script: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- Windows Implementation --- + +func isWindowsHandlerInstalled() bool { + // Check registry key existence + cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve") + return cmd.Run() == nil +} + +func installWindowsHandler(handlerPort int) error { + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + // Create handler script (PowerShell) + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + if err := os.MkdirAll(scriptDir, 0755); err != nil { + return fmt.Errorf("failed to create script directory: %w", err) + } + + scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1") + scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows +param([string]$url) + +# Load required assembly for HttpUtility +Add-Type -AssemblyName System.Web + +# Parse URL parameters +$uri = [System.Uri]$url +$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query) +$code = $query["code"] +$state = $query["state"] +$errorParam = $query["error"] + +# Try multiple ports (default + dynamic range) +$ports = @(%d, %d, %d, %d, %d) +$success = $false + +foreach ($port in $ports) { + if ($success) { break } + $callbackUrl = "http://127.0.0.1:$port/oauth/callback" + try { + if ($errorParam) { + $fullUrl = $callbackUrl + "?error=" + $errorParam + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } elseif ($code -and $state) { + $fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } + } catch { + # Try next port + } +} +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create batch wrapper + batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat") + batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" \"%%1\"\n", scriptPath) + + if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil { + return fmt.Errorf("failed to write batch wrapper: %w", err) + } + + // Register in Windows registry + commands := [][]string{ + {"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run registry command: %w", err) + } + } + + log.Info("Kiro protocol handler installed for Windows") + return nil +} + +func uninstallWindowsHandler() error { + // Remove registry keys + cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f") + if err := cmd.Run(); err != nil { + log.Warnf("failed to remove registry key: %v", err) + } + + // Remove scripts + homeDir, _ := os.UserHomeDir() + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1")) + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat")) + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- macOS Implementation --- + +func getDarwinAppPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app") +} + +func isDarwinHandlerInstalled() bool { + appPath := getDarwinAppPath() + _, err := os.Stat(appPath) + return err == nil +} + +func installDarwinHandler(handlerPort int) error { + // Create app bundle structure + appPath := getDarwinAppPath() + contentsPath := filepath.Join(appPath, "Contents") + macOSPath := filepath.Join(contentsPath, "MacOS") + + if err := os.MkdirAll(macOSPath, 0755); err != nil { + return fmt.Errorf("failed to create app bundle: %w", err) + } + + // Create Info.plist + plistPath := filepath.Join(contentsPath, "Info.plist") + plistContent := ` + + + + CFBundleIdentifier + com.cliproxyapi.kiro-oauth-handler + CFBundleName + KiroOAuthHandler + CFBundleExecutable + kiro-oauth-handler + CFBundleVersion + 1.0 + CFBundleURLTypes + + + CFBundleURLName + Kiro Protocol + CFBundleURLSchemes + + kiro + + + + LSBackgroundOnly + + +` + + if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil { + return fmt.Errorf("failed to write Info.plist: %w", err) + } + + // Create executable script - tries multiple ports to handle dynamic port allocation + execPath := filepath.Join(macOSPath, "kiro-oauth-handler") + execContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler for macOS + +URL="$1" + +# Check curl availability (should always exist on macOS) +if [ ! -x /usr/bin/curl ]; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try multiple ports (default + dynamic range) +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0 + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0 + fi +done +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil { + return fmt.Errorf("failed to write executable: %w", err) + } + + // Register the app with Launch Services + cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister", + "-f", appPath) + if err := cmd.Run(); err != nil { + log.Warnf("lsregister failed (handler may still work): %v", err) + } + + log.Info("Kiro protocol handler installed for macOS") + return nil +} + +func uninstallDarwinHandler() error { + appPath := getDarwinAppPath() + + // Unregister from Launch Services + cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister", + "-u", appPath) + _ = cmd.Run() + + // Remove app bundle + if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove app bundle: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// ParseKiroURI parses a kiro:// URI and extracts the callback parameters. +func ParseKiroURI(rawURI string) (*AuthCallback, error) { + u, err := url.Parse(rawURI) + if err != nil { + return nil, fmt.Errorf("invalid URI: %w", err) + } + + if u.Scheme != KiroProtocol { + return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme) + } + + if u.Host != KiroAuthority { + return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host) + } + + query := u.Query() + return &AuthCallback{ + Code: query.Get("code"), + State: query.Get("state"), + Error: query.Get("error"), + }, nil +} + +// GetHandlerInstructions returns platform-specific instructions for manual handler setup. +func GetHandlerInstructions() string { + switch runtime.GOOS { + case "linux": + return `To manually set up the Kiro protocol handler on Linux: + +1. Create ~/.local/share/applications/kiro-oauth-handler.desktop: + [Desktop Entry] + Name=Kiro OAuth Handler + Exec=~/.local/bin/kiro-oauth-handler %u + Type=Application + Terminal=false + MimeType=x-scheme-handler/kiro; + +2. Create ~/.local/bin/kiro-oauth-handler (make it executable): + #!/bin/bash + URL="$1" + # ... (see generated script for full content) + +3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro` + + case "windows": + return `To manually set up the Kiro protocol handler on Windows: + +1. Open Registry Editor (regedit.exe) +2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro +3. Set default value to: URL:Kiro Protocol +4. Create string value "URL Protocol" with empty data +5. Create subkey: shell\open\command +6. Set default value to: "C:\path\to\handler.bat" "%1"` + + case "darwin": + return `To manually set up the Kiro protocol handler on macOS: + +1. Create ~/Applications/KiroOAuthHandler.app bundle +2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme +3. Create executable in Contents/MacOS/ +4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app` + + default: + return "Protocol handler setup is not supported on this platform." + } +} + +// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed. +func SetupProtocolHandlerIfNeeded(handlerPort int) error { + if IsProtocolHandlerInstalled() { + log.Debug("Kiro protocol handler already installed") + return nil + } + + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Protocol Handler Setup Required ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.") + fmt.Println("This allows your browser to redirect back to the CLI after authentication.") + fmt.Println("\nInstalling protocol handler...") + + if err := InstallProtocolHandler(handlerPort); err != nil { + fmt.Printf("\n⚠ Automatic installation failed: %v\n", err) + fmt.Println("\nManual setup instructions:") + fmt.Println(strings.Repeat("-", 60)) + fmt.Println(GetHandlerInstructions()) + return err + } + + fmt.Println("\n✓ Protocol handler installed successfully!") + return nil +} diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go new file mode 100644 index 00000000..61c67886 --- /dev/null +++ b/internal/auth/kiro/social_auth.go @@ -0,0 +1,403 @@ +// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient. +package kiro + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" + "golang.org/x/term" +) + +const ( + // Kiro AuthService endpoint + kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" + + // OAuth timeout + socialAuthTimeout = 10 * time.Minute +) + +// SocialProvider represents the social login provider. +type SocialProvider string + +const ( + // ProviderGoogle is Google OAuth provider + ProviderGoogle SocialProvider = "Google" + // ProviderGitHub is GitHub OAuth provider + ProviderGitHub SocialProvider = "Github" + // Note: AWS Builder ID is NOT supported by Kiro's auth service. + // It only supports: Google, Github, Cognito + // AWS Builder ID must use device code flow via SSO OIDC. +) + +// CreateTokenRequest is sent to Kiro's /oauth/token endpoint. +type CreateTokenRequest struct { + Code string `json:"code"` + CodeVerifier string `json:"code_verifier"` + RedirectURI string `json:"redirect_uri"` + InvitationCode string `json:"invitation_code,omitempty"` +} + +// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth. +type SocialTokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ProfileArn string `json:"profileArn"` + ExpiresIn int `json:"expiresIn"` +} + +// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint. +type RefreshTokenRequest struct { + RefreshToken string `json:"refreshToken"` +} + +// SocialAuthClient handles social authentication with Kiro. +type SocialAuthClient struct { + httpClient *http.Client + cfg *config.Config + protocolHandler *ProtocolHandler +} + +// NewSocialAuthClient creates a new social auth client. +func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SocialAuthClient{ + httpClient: client, + cfg: cfg, + protocolHandler: NewProtocolHandler(), + } +} + +// generatePKCE generates PKCE code verifier and challenge. +func generatePKCE() (verifier, challenge string, err error) { + // Generate 32 bytes of random data for verifier + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + // Generate SHA256 hash of verifier for challenge + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} + +// generateState generates a random state parameter. +func generateStateParam() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// buildLoginURL constructs the Kiro OAuth login URL. +// The login endpoint expects a GET request with query parameters. +// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account +// The prompt=select_account parameter forces the account selection screen even if already logged in. +func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string { + return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + kiroAuthServiceEndpoint, + provider, + url.QueryEscape(redirectURI), + codeChallenge, + state, + ) +} + +// createToken exchanges the authorization code for tokens. +func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal token request: %w", err) + } + + tokenURL := kiroAuthServiceEndpoint + "/oauth/token" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &tokenResp, nil +} + +// RefreshSocialToken refreshes an expired social auth token. +func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken}) + if err != nil { + return nil, fmt.Errorf("failed to marshal refresh request: %w", err) + } + + refreshURL := kiroAuthServiceEndpoint + "/refreshToken" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse refresh response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 // Default 1 hour + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// LoginWithSocial performs OAuth login with Google. +func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) { + providerName := string(provider) + + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName) + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Setup protocol handler + fmt.Println("\nSetting up authentication...") + + // Start the local callback server + handlerPort, err := c.protocolHandler.Start(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + defer c.protocolHandler.Stop() + + // Ensure protocol handler is installed and set as default + if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil { + fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...") + fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.") + fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol") + log.Debugf("kiro: protocol handler setup error: %v", err) + // Continue anyway - user might have set it up manually or select browser manually + } else { + // Force set our handler as default (prevents "Open with" dialog) + forceDefaultProtocolHandler() + } + + // Step 2: Generate PKCE codes + codeVerifier, codeChallenge, err := generatePKCE() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + // Step 3: Generate state + state, err := generateStateParam() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // Step 4: Build the login URL (Kiro uses GET request with query params) + authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state) + + // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) + // Incognito mode enables multi-account support by bypassing cached sessions + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) // Default to incognito if no config + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Step 5: Open browser for user authentication + fmt.Println("\n════════════════════════════════════════════════════════════") + fmt.Printf(" Opening browser for %s authentication...\n", providerName) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n URL: %s\n\n", authURL) + + if err := browser.OpenURL(authURL); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" ⚠ Could not open browser automatically.") + fmt.Println(" Please open the URL above in your browser manually.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + fmt.Println("\n Waiting for authentication callback...") + + // Step 6: Wait for callback + callback, err := c.protocolHandler.WaitForCallback(ctx) + if err != nil { + return nil, fmt.Errorf("failed to receive callback: %w", err) + } + + if callback.Error != "" { + return nil, fmt.Errorf("authentication error: %s", callback.Error) + } + + if callback.State != state { + // Log state values for debugging, but don't expose in user-facing error + log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State) + return nil, fmt.Errorf("OAuth state validation failed - please try again") + } + + if callback.Code == "" { + return nil, fmt.Errorf("no authorization code received") + } + + fmt.Println("\n✓ Authorization received!") + + // Step 7: Exchange code for tokens + fmt.Println("Exchanging code for tokens...") + + tokenReq := &CreateTokenRequest{ + Code: callback.Code, + CodeVerifier: codeVerifier, + RedirectURI: KiroRedirectURI, + } + + tokenResp, err := c.createToken(ctx, tokenReq) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + // Try to extract email from JWT access token first + email := ExtractEmailFromJWT(tokenResp.AccessToken) + + // If no email in JWT, ask user for account label (only in interactive mode) + if email == "" && isInteractiveTerminal() { + fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ") + reader := bufio.NewReader(os.Stdin) + var err error + email, err = reader.ReadString('\n') + if err != nil { + log.Debugf("Failed to read account label: %v", err) + } + email = strings.TrimSpace(email) + } + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: providerName, + Email: email, // JWT email or user-provided label + }, nil +} + +// LoginWithGoogle performs OAuth login with Google. +func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + return c.LoginWithSocial(ctx, ProviderGoogle) +} + +// LoginWithGitHub performs OAuth login with GitHub. +func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + return c.LoginWithSocial(ctx, ProviderGitHub) +} + +// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs. +// This prevents the "Open with" dialog from appearing on Linux. +// On non-Linux platforms, this is a no-op as they use different mechanisms. +func forceDefaultProtocolHandler() { + if runtime.GOOS != "linux" { + return // Non-Linux platforms use different handler mechanisms + } + + // Set our handler as default using xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err) + } +} + +// isInteractiveTerminal checks if stdin is connected to an interactive terminal. +// Returns false in CI/automated environments or when stdin is piped. +func isInteractiveTerminal() bool { + return term.IsTerminal(int(os.Stdin.Fd())) +} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go new file mode 100644 index 00000000..d3c27d16 --- /dev/null +++ b/internal/auth/kiro/sso_oidc.go @@ -0,0 +1,527 @@ +// Package kiro provides AWS SSO OIDC authentication for Kiro. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // AWS SSO OIDC endpoints + ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" + + // Kiro's start URL for Builder ID + builderIDStartURL = "https://view.awsapps.com/start" + + // Polling interval + pollInterval = 5 * time.Second +) + +// SSOOIDCClient handles AWS SSO OIDC authentication. +type SSOOIDCClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewSSOOIDCClient creates a new SSO OIDC client. +func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SSOOIDCClient{ + httpClient: client, + cfg: cfg, + } +} + +// RegisterClientResponse from AWS SSO OIDC. +type RegisterClientResponse struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` + ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` +} + +// StartDeviceAuthResponse from AWS SSO OIDC. +type StartDeviceAuthResponse struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationURI string `json:"verificationUri"` + VerificationURIComplete string `json:"verificationUriComplete"` + ExpiresIn int `json:"expiresIn"` + Interval int `json:"interval"` +} + +// CreateTokenResponse from AWS SSO OIDC. +type CreateTokenResponse struct { + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + ExpiresIn int `json:"expiresIn"` + RefreshToken string `json:"refreshToken"` +} + +// RegisterClient registers a new OIDC client with AWS. +func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { + // Generate unique client name for each registration to support multiple accounts + clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano()) + + payload := map[string]interface{}{ + "clientName": clientName, + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorization starts the device authorization flow. +func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateToken polls for the access token after user authorization. +func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, fmt.Errorf("authorization_pending") + } + if errResp.Error == "slow_down" { + return nil, fmt.Errorf("slow_down") + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshToken refreshes an access token using the refresh token. +func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + }, nil +} + +// LoginWithBuilderID performs the full device code flow for AWS Builder ID. +func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Builder ID) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Register client + fmt.Println("\nRegistering client...") + regResp, err := c.RegisterClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + return nil, fmt.Errorf("failed to start device auth: %w", err) + } + + // Step 3: Show user the verification URL + fmt.Printf("\n") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf(" Open this URL in your browser:\n") + fmt.Printf(" %s\n", authResp.VerificationURIComplete) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI) + fmt.Printf(" And enter code: %s\n\n", authResp.UserCode) + + // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) + // Incognito mode enables multi-account support by bypassing cached sessions + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) // Default to incognito if no config + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Open browser using cross-platform browser package + if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" Please open the URL manually in your browser.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + // Step 4: Poll for token + fmt.Println("Waiting for authorization...") + + interval := pollInterval + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + browser.CloseBrowser() // Cleanup on cancel + return nil, ctx.Err() + case <-time.After(interval): + tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + fmt.Print(".") + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + // Close browser on error before returning + browser.CloseBrowser() + return nil, fmt.Errorf("token creation failed: %w", err) + } + + fmt.Println("\n\n✓ Authorization successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 5: Get profile ARN from CodeWhisperer API + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Extract email from JWT access token + email := ExtractEmailFromJWT(tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + }, nil + } + } + + // Close browser on timeout for better UX + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} + +// fetchProfileArn retrieves the profile ARN from CodeWhisperer API. +// This is needed for file naming since AWS SSO OIDC doesn't return profile info. +func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string { + // Try ListProfiles API first + profileArn := c.tryListProfiles(ctx, accessToken) + if profileArn != "" { + return profileArn + } + + // Fallback: Try ListAvailableCustomizations + return c.tryListCustomizations(ctx, accessToken) +} + +func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListProfiles response: %s", string(respBody)) + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + } `json:"profiles"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Profiles) > 0 { + return result.Profiles[0].Arn + } + + return "" +} + +func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListAvailableCustomizations response: %s", string(respBody)) + + var result struct { + Customizations []struct { + Arn string `json:"arn"` + } `json:"customizations"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Customizations) > 0 { + return result.Customizations[0].Arn + } + + return "" +} diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go new file mode 100644 index 00000000..e83b1728 --- /dev/null +++ b/internal/auth/kiro/token.go @@ -0,0 +1,72 @@ +package kiro + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// KiroTokenStorage holds the persistent token data for Kiro authentication. +type KiroTokenStorage struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"access_token"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refresh_token"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profile_arn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expires_at"` + // AuthMethod indicates the authentication method used + AuthMethod string `json:"auth_method"` + // Provider indicates the OAuth provider + Provider string `json:"provider"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// SaveTokenToFile persists the token storage to the specified file path. +func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error { + dir := filepath.Dir(authFilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal token storage: %w", err) + } + + if err := os.WriteFile(authFilePath, data, 0600); err != nil { + return fmt.Errorf("failed to write token file: %w", err) + } + + return nil +} + +// LoadFromFile loads token storage from the specified file path. +func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) { + data, err := os.ReadFile(authFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var storage KiroTokenStorage + if err := json.Unmarshal(data, &storage); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &storage, nil +} + +// ToTokenData converts storage to KiroTokenData for API use. +func (s *KiroTokenStorage) ToTokenData() *KiroTokenData { + return &KiroTokenData{ + AccessToken: s.AccessToken, + RefreshToken: s.RefreshToken, + ProfileArn: s.ProfileArn, + ExpiresAt: s.ExpiresAt, + AuthMethod: s.AuthMethod, + Provider: s.Provider, + } +} diff --git a/internal/browser/browser.go b/internal/browser/browser.go index b24dc5e1..1ff0b469 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -6,14 +6,48 @@ import ( "fmt" "os/exec" "runtime" + "sync" + pkgbrowser "github.com/pkg/browser" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" ) +// incognitoMode controls whether to open URLs in incognito/private mode. +// This is useful for OAuth flows where you want to use a different account. +var incognitoMode bool + +// lastBrowserProcess stores the last opened browser process for cleanup +var lastBrowserProcess *exec.Cmd +var browserMutex sync.Mutex + +// SetIncognitoMode enables or disables incognito/private browsing mode. +func SetIncognitoMode(enabled bool) { + incognitoMode = enabled +} + +// IsIncognitoMode returns whether incognito mode is enabled. +func IsIncognitoMode() bool { + return incognitoMode +} + +// CloseBrowser closes the last opened browser process. +func CloseBrowser() error { + browserMutex.Lock() + defer browserMutex.Unlock() + + if lastBrowserProcess == nil || lastBrowserProcess.Process == nil { + return nil + } + + err := lastBrowserProcess.Process.Kill() + lastBrowserProcess = nil + return err +} + // OpenURL opens the specified URL in the default web browser. -// It first attempts to use a platform-agnostic library and falls back to -// platform-specific commands if that fails. +// It uses the pkg/browser library which provides robust cross-platform support +// for Windows, macOS, and Linux. +// If incognito mode is enabled, it will open in a private/incognito window. // // Parameters: // - url: The URL to open. @@ -21,16 +55,22 @@ import ( // Returns: // - An error if the URL cannot be opened, otherwise nil. func OpenURL(url string) error { - fmt.Printf("Attempting to open URL in browser: %s\n", url) + log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode) - // Try using the open-golang library first - err := open.Run(url) + // If incognito mode is enabled, use platform-specific incognito commands + if incognitoMode { + log.Debug("Using incognito mode") + return openURLIncognito(url) + } + + // Use pkg/browser for cross-platform support + err := pkgbrowser.OpenURL(url) if err == nil { - log.Debug("Successfully opened URL using open-golang library") + log.Debug("Successfully opened URL using pkg/browser library") return nil } - log.Debugf("open-golang failed: %v, trying platform-specific commands", err) + log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err) // Fallback to platform-specific commands return openURLPlatformSpecific(url) @@ -78,18 +118,394 @@ func openURLPlatformSpecific(url string) error { return nil } +// openURLIncognito opens a URL in incognito/private browsing mode. +// It first tries to detect the default browser and use its incognito flag. +// Falls back to a chain of known browsers if detection fails. +// +// Parameters: +// - url: The URL to open. +// +// Returns: +// - An error if the URL cannot be opened, otherwise nil. +func openURLIncognito(url string) error { + // First, try to detect and use the default browser + if cmd := tryDefaultBrowserIncognito(url); cmd != nil { + log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:]) + if err := cmd.Start(); err == nil { + storeBrowserProcess(cmd) + log.Debug("Successfully opened URL in default browser's incognito mode") + return nil + } + log.Debugf("Failed to start default browser, trying fallback chain") + } + + // Fallback to known browser chain + cmd := tryFallbackBrowsersIncognito(url) + if cmd == nil { + log.Warn("No browser with incognito support found, falling back to normal mode") + return openURLPlatformSpecific(url) + } + + log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:]) + err := cmd.Start() + if err != nil { + log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err) + return openURLPlatformSpecific(url) + } + + storeBrowserProcess(cmd) + log.Debug("Successfully opened URL in incognito/private mode") + return nil +} + +// storeBrowserProcess safely stores the browser process for later cleanup. +func storeBrowserProcess(cmd *exec.Cmd) { + browserMutex.Lock() + lastBrowserProcess = cmd + browserMutex.Unlock() +} + +// tryDefaultBrowserIncognito attempts to detect the default browser and return +// an exec.Cmd configured with the appropriate incognito flag. +func tryDefaultBrowserIncognito(url string) *exec.Cmd { + switch runtime.GOOS { + case "darwin": + return tryDefaultBrowserMacOS(url) + case "windows": + return tryDefaultBrowserWindows(url) + case "linux": + return tryDefaultBrowserLinux(url) + } + return nil +} + +// tryDefaultBrowserMacOS detects the default browser on macOS. +func tryDefaultBrowserMacOS(url string) *exec.Cmd { + // Try to get default browser from Launch Services + out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output() + if err != nil { + return nil + } + + output := string(out) + var browserName string + + // Parse the output to find the http/https handler + if containsBrowserID(output, "com.google.chrome") { + browserName = "chrome" + } else if containsBrowserID(output, "org.mozilla.firefox") { + browserName = "firefox" + } else if containsBrowserID(output, "com.apple.safari") { + browserName = "safari" + } else if containsBrowserID(output, "com.brave.browser") { + browserName = "brave" + } else if containsBrowserID(output, "com.microsoft.edgemac") { + browserName = "edge" + } + + return createMacOSIncognitoCmd(browserName, url) +} + +// containsBrowserID checks if the LaunchServices output contains a browser ID. +func containsBrowserID(output, bundleID string) bool { + return stringContains(output, bundleID) +} + +// stringContains is a simple string contains check. +func stringContains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers. +func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + // Try direct path first + chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" + if _, err := exec.LookPath(chromePath); err == nil { + return exec.Command(chromePath, "--incognito", url) + } + return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url) + case "firefox": + return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url) + case "safari": + // Safari doesn't have CLI incognito, try AppleScript + return tryAppleScriptSafariPrivate(url) + case "brave": + return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url) + case "edge": + return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url) + } + return nil +} + +// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript. +func tryAppleScriptSafariPrivate(url string) *exec.Cmd { + // AppleScript to open a new private window in Safari + script := fmt.Sprintf(` + tell application "Safari" + activate + tell application "System Events" + keystroke "n" using {command down, shift down} + delay 0.5 + end tell + set URL of document 1 to "%s" + end tell + `, url) + + cmd := exec.Command("osascript", "-e", script) + // Test if this approach works by checking if Safari is available + if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil { + log.Debug("Safari not found, AppleScript private window not available") + return nil + } + log.Debug("Attempting Safari private window via AppleScript") + return cmd +} + +// tryDefaultBrowserWindows detects the default browser on Windows via registry. +func tryDefaultBrowserWindows(url string) *exec.Cmd { + // Query registry for default browser + out, err := exec.Command("reg", "query", + `HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`, + "/v", "ProgId").Output() + if err != nil { + return nil + } + + output := string(out) + var browserName string + + // Map ProgId to browser name + if stringContains(output, "ChromeHTML") { + browserName = "chrome" + } else if stringContains(output, "FirefoxURL") { + browserName = "firefox" + } else if stringContains(output, "MSEdgeHTM") { + browserName = "edge" + } else if stringContains(output, "BraveHTML") { + browserName = "brave" + } + + return createWindowsIncognitoCmd(browserName, url) +} + +// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers. +func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + paths := []string{ + "chrome", + `C:\Program Files\Google\Chrome\Application\chrome.exe`, + `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`, + } + for _, p := range paths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + case "firefox": + if path, err := exec.LookPath("firefox"); err == nil { + return exec.Command(path, "--private-window", url) + } + case "edge": + paths := []string{ + "msedge", + `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`, + `C:\Program Files\Microsoft\Edge\Application\msedge.exe`, + } + for _, p := range paths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--inprivate", url) + } + } + case "brave": + paths := []string{ + `C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`, + `C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`, + } + for _, p := range paths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + } + return nil +} + +// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings. +func tryDefaultBrowserLinux(url string) *exec.Cmd { + out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output() + if err != nil { + return nil + } + + desktop := string(out) + var browserName string + + // Map .desktop file to browser name + if stringContains(desktop, "google-chrome") || stringContains(desktop, "chrome") { + browserName = "chrome" + } else if stringContains(desktop, "firefox") { + browserName = "firefox" + } else if stringContains(desktop, "chromium") { + browserName = "chromium" + } else if stringContains(desktop, "brave") { + browserName = "brave" + } else if stringContains(desktop, "microsoft-edge") || stringContains(desktop, "msedge") { + browserName = "edge" + } + + return createLinuxIncognitoCmd(browserName, url) +} + +// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers. +func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + paths := []string{"google-chrome", "google-chrome-stable"} + for _, p := range paths { + if path, err := exec.LookPath(p); err == nil { + return exec.Command(path, "--incognito", url) + } + } + case "firefox": + paths := []string{"firefox", "firefox-esr"} + for _, p := range paths { + if path, err := exec.LookPath(p); err == nil { + return exec.Command(path, "--private-window", url) + } + } + case "chromium": + paths := []string{"chromium", "chromium-browser"} + for _, p := range paths { + if path, err := exec.LookPath(p); err == nil { + return exec.Command(path, "--incognito", url) + } + } + case "brave": + if path, err := exec.LookPath("brave-browser"); err == nil { + return exec.Command(path, "--incognito", url) + } + case "edge": + if path, err := exec.LookPath("microsoft-edge"); err == nil { + return exec.Command(path, "--inprivate", url) + } + } + return nil +} + +// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback. +func tryFallbackBrowsersIncognito(url string) *exec.Cmd { + switch runtime.GOOS { + case "darwin": + return tryFallbackBrowsersMacOS(url) + case "windows": + return tryFallbackBrowsersWindows(url) + case "linux": + return tryFallbackBrowsersLinuxChain(url) + } + return nil +} + +// tryFallbackBrowsersMacOS tries known browsers on macOS. +func tryFallbackBrowsersMacOS(url string) *exec.Cmd { + // Try Chrome + chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" + if _, err := exec.LookPath(chromePath); err == nil { + return exec.Command(chromePath, "--incognito", url) + } + // Try Firefox + if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil { + return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url) + } + // Try Brave + if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil { + return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url) + } + // Try Edge + if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil { + return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url) + } + // Last resort: try Safari with AppleScript + if cmd := tryAppleScriptSafariPrivate(url); cmd != nil { + log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)") + return cmd + } + return nil +} + +// tryFallbackBrowsersWindows tries known browsers on Windows. +func tryFallbackBrowsersWindows(url string) *exec.Cmd { + // Chrome + chromePaths := []string{ + "chrome", + `C:\Program Files\Google\Chrome\Application\chrome.exe`, + `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`, + } + for _, p := range chromePaths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + // Firefox + if path, err := exec.LookPath("firefox"); err == nil { + return exec.Command(path, "--private-window", url) + } + // Edge (usually available on Windows 10+) + edgePaths := []string{ + "msedge", + `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`, + `C:\Program Files\Microsoft\Edge\Application\msedge.exe`, + } + for _, p := range edgePaths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--inprivate", url) + } + } + return nil +} + +// tryFallbackBrowsersLinuxChain tries known browsers on Linux. +func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd { + type browserConfig struct { + name string + flag string + } + browsers := []browserConfig{ + {"google-chrome", "--incognito"}, + {"google-chrome-stable", "--incognito"}, + {"chromium", "--incognito"}, + {"chromium-browser", "--incognito"}, + {"firefox", "--private-window"}, + {"firefox-esr", "--private-window"}, + {"brave-browser", "--incognito"}, + {"microsoft-edge", "--inprivate"}, + } + for _, b := range browsers { + if path, err := exec.LookPath(b.name); err == nil { + return exec.Command(path, b.flag, url) + } + } + return nil +} + // IsAvailable checks if the system has a command available to open a web browser. // It verifies the presence of necessary commands for the current operating system. // // Returns: // - true if a browser can be opened, false otherwise. func IsAvailable() bool { - // First check if open-golang can work - testErr := open.Run("about:blank") - if testErr == nil { - return true - } - // Check platform-specific commands switch runtime.GOOS { case "darwin": diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index e6caa954..692ea34d 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), + sdkAuth.NewKiroAuthenticator(), ) return manager } diff --git a/internal/cmd/kiro_login.go b/internal/cmd/kiro_login.go new file mode 100644 index 00000000..5fc3b9eb --- /dev/null +++ b/internal/cmd/kiro_login.go @@ -0,0 +1,160 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoKiroLogin triggers the Kiro authentication flow with Google OAuth. +// This is the default login method (same as --kiro-google-login). +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including Prompt field +func DoKiroLogin(cfg *config.Config, options *LoginOptions) { + // Use Google login as default + DoKiroGoogleLogin(cfg, options) +} + +// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth. +// This uses a custom protocol handler (kiro://) to receive the callback. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with Google login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro Google authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure the protocol handler is installed") + fmt.Println("2. Complete the Google login in the browser") + fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro Google authentication successful!") +} + +// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID. +// This uses the device code flow for AWS SSO OIDC authentication. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with AWS Builder ID login (device code flow) + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro AWS authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure you have an AWS Builder ID") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro AWS authentication successful!") +} + +// DoKiroImport imports Kiro token from Kiro IDE's token file. +// This is useful for users who have already logged in via Kiro IDE +// and want to use the same credentials in CLI Proxy API. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options (currently unused for import) +func DoKiroImport(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + // Use ImportFromKiroIDE instead of Login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg) + if err != nil { + log.Errorf("Kiro token import failed: %v", err) + fmt.Println("\nMake sure you have logged in to Kiro IDE first:") + fmt.Println("1. Open Kiro IDE") + fmt.Println("2. Click 'Sign in with Google' (or GitHub)") + fmt.Println("3. Complete the login process") + fmt.Println("4. Run this command again") + return + } + + // Save the imported auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Imported as %s\n", record.Label) + } + fmt.Println("Kiro token import successful!") +} diff --git a/internal/config/config.go b/internal/config/config.go index 2681d049..1c72ece4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -58,6 +58,9 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations. + KiroKey []KiroKey `yaml:"kiro" json:"kiro"` + // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` @@ -80,6 +83,11 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode. + // This is useful when you want to login with a different account without logging out + // from your current session. Default: false. + IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"` + legacyMigrationPending bool `yaml:"-" json:"-"` } @@ -240,6 +248,31 @@ type GeminiKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication. +type KiroKey struct { + // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json) + TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"` + + // AccessToken is the OAuth access token for direct configuration. + AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"` + + // RefreshToken is the OAuth refresh token for token renewal. + RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"` + + // ProfileArn is the AWS CodeWhisperer profile ARN. + ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"` + + // Region is the AWS region (default: us-east-1). + Region string `yaml:"region,omitempty" json:"region,omitempty"` + + // ProxyURL optionally overrides the global proxy for this configuration. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat". + // Leave empty to let API use defaults. Different values may inject different system prompts. + AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"` +} + // OpenAICompatibility represents the configuration for OpenAI API compatibility // with external providers, allowing model aliases to be routed through OpenAI API format. type OpenAICompatibility struct { @@ -320,6 +353,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access + cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) if err = yaml.Unmarshal(data, &cfg); err != nil { if optional { // In cloud deploy mode, if YAML parsing fails, return empty config instead of error. @@ -370,6 +404,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Claude key headers cfg.SanitizeClaudeKeys() + // Sanitize Kiro keys: trim whitespace from credential fields + cfg.SanitizeKiroKeys() + // Sanitize OpenAI compatibility providers: drop entries without base-url cfg.SanitizeOpenAICompatibility() @@ -446,6 +483,22 @@ func (cfg *Config) SanitizeClaudeKeys() { } } +// SanitizeKiroKeys trims whitespace from Kiro credential fields. +func (cfg *Config) SanitizeKiroKeys() { + if cfg == nil || len(cfg.KiroKey) == 0 { + return + } + for i := range cfg.KiroKey { + entry := &cfg.KiroKey[i] + entry.TokenFile = strings.TrimSpace(entry.TokenFile) + entry.AccessToken = strings.TrimSpace(entry.AccessToken) + entry.RefreshToken = strings.TrimSpace(entry.RefreshToken) + entry.ProfileArn = strings.TrimSpace(entry.ProfileArn) + entry.Region = strings.TrimSpace(entry.Region) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + } +} + // SanitizeGeminiKeys deduplicates and normalizes Gemini credentials. func (cfg *Config) SanitizeGeminiKeys() { if cfg == nil { diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 58b388a1..1dbeecde 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -24,4 +24,7 @@ const ( // Antigravity represents the Antigravity response format identifier. Antigravity = "antigravity" + + // Kiro represents the AWS CodeWhisperer (Kiro) provider identifier. + Kiro = "kiro" ) diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index 28fde213..74a79efc 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -38,13 +38,16 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { timestamp := entry.Time.Format("2006-01-02 15:04:05") message := strings.TrimRight(entry.Message, "\r\n") - - var formatted string + + // Handle nil Caller (can happen with some log entries) + callerFile := "unknown" + callerLine := 0 if entry.Caller != nil { - formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message) - } else { - formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message) + callerFile = filepath.Base(entry.Caller.File) + callerLine = entry.Caller.Line } + + formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message) buffer.WriteString(formatted) return buffer.Bytes(), nil @@ -55,6 +58,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { func SetupBaseLogger() { setupOnce.Do(func() { log.SetOutput(os.Stdout) + log.SetLevel(log.InfoLevel) log.SetReportCaller(true) log.SetFormatter(&LogFormatter{}) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 36aa83bb..9f10eeac 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -986,3 +986,161 @@ func GetIFlowModels() []*ModelInfo { } return models } + +// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions +func GetKiroModels() []*ModelInfo { + return []*ModelInfo{ + { + ID: "kiro-auto", + Object: "model", + Created: 1732752000, // 2024-11-28 + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Auto", + Description: "Automatic model selection by AWS CodeWhisperer", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-opus-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5", + Description: "Claude Opus 4.5 via Kiro (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-sonnet-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4", + Description: "Claude Sonnet 4 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-haiku-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Kiro (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + // --- Chat Variant (No tool calling, for pure conversation) --- + { + ID: "kiro-claude-opus-4.5-chat", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5 (Chat)", + Description: "Claude Opus 4.5 for chat only (no tool calling)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + // --- Agentic Variants (Optimized for coding agents with chunked writes) --- + { + ID: "kiro-claude-opus-4.5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5 (Agentic)", + Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-sonnet-4.5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)", + Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} + +// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions. +// These models use the same API as Kiro and share the same executor. +func GetAmazonQModels() []*ModelInfo { + return []*ModelInfo{ + { + ID: "amazonq-auto", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", // Uses Kiro executor - same API + DisplayName: "Amazon Q Auto", + Description: "Automatic model selection by Amazon Q", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-opus-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Opus 4.5", + Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4", + Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-haiku-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 73914750..c124c829 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -12,6 +12,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "github.com/google/uuid" @@ -41,7 +42,10 @@ const ( streamScannerBuffer int = 20_971_520 ) -var randSource = rand.New(rand.NewSource(time.Now().UnixNano())) +var ( + randSource = rand.New(rand.NewSource(time.Now().UnixNano())) + randSourceMutex sync.Mutex +) // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { @@ -754,15 +758,19 @@ func generateRequestID() string { } func generateSessionID() string { + randSourceMutex.Lock() n := randSource.Int63n(9_000_000_000_000_000_000) + randSourceMutex.Unlock() return "-" + strconv.FormatInt(n, 10) } func generateProjectID() string { adjectives := []string{"useful", "bright", "swift", "calm", "bold"} nouns := []string{"fuze", "wave", "spark", "flow", "core"} + randSourceMutex.Lock() adj := adjectives[randSource.Intn(len(adjectives))] noun := nouns[randSource.Intn(len(nouns))] + randSourceMutex.Unlock() randomPart := strings.ToLower(uuid.NewString())[:5] return adj + "-" + noun + "-" + randomPart } diff --git a/internal/runtime/executor/cache_helpers.go b/internal/runtime/executor/cache_helpers.go index 5272686b..4b553662 100644 --- a/internal/runtime/executor/cache_helpers.go +++ b/internal/runtime/executor/cache_helpers.go @@ -1,10 +1,38 @@ package executor -import "time" +import ( + "sync" + "time" +) type codexCache struct { ID string Expire time.Time } -var codexCacheMap = map[string]codexCache{} +var ( + codexCacheMap = map[string]codexCache{} + codexCacheMutex sync.RWMutex +) + +// getCodexCache safely retrieves a cache entry +func getCodexCache(key string) (codexCache, bool) { + codexCacheMutex.RLock() + defer codexCacheMutex.RUnlock() + cache, ok := codexCacheMap[key] + return cache, ok +} + +// setCodexCache safely sets a cache entry +func setCodexCache(key string, cache codexCache) { + codexCacheMutex.Lock() + defer codexCacheMutex.Unlock() + codexCacheMap[key] = cache +} + +// deleteCodexCache safely deletes a cache entry +func deleteCodexCache(key string) { + codexCacheMutex.Lock() + defer codexCacheMutex.Unlock() + delete(codexCacheMap, key) +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 1c4291f6..c14b7be8 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -506,12 +506,12 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form if userIDResult.Exists() { var hasKey bool key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) - if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) { + if cache, hasKey = getCodexCache(key); !hasKey || cache.Expire.Before(time.Now()) { cache = codexCache{ ID: uuid.New().String(), Expire: time.Now().Add(1 * time.Hour), } - codexCacheMap[key] = cache + setCodexCache(key, cache) } } } else if from == "openai-response" { diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go new file mode 100644 index 00000000..eb068c14 --- /dev/null +++ b/internal/runtime/executor/kiro_executor.go @@ -0,0 +1,2353 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/google/uuid" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/gin-gonic/gin" +) + +const ( + // kiroEndpoint is the Amazon Q streaming endpoint for chat API (GenerateAssistantResponse). + // Note: This is different from the CodeWhisperer management endpoint (codewhisperer.us-east-1.amazonaws.com) + // used in aws_auth.go for GetUsageLimits, ListProfiles, etc. Both endpoints are correct + // for their respective API operations. + kiroEndpoint = "https://q.us-east-1.amazonaws.com" + kiroTargetChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" + kiroContentType = "application/x-amz-json-1.0" + kiroAcceptStream = "application/vnd.amazon.eventstream" + kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream + kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." + + // kiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. + // AWS Kiro API has a 2-3 minute timeout for large file write operations. + kiroAgenticSystemPrompt = ` +# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY) + +You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure. + +## ABSOLUTE LIMITS +- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS +- **RECOMMENDED 300 LINES** or less for optimal performance +- **NEVER** write entire files in one operation if >300 lines + +## MANDATORY CHUNKED WRITE STRATEGY + +### For NEW FILES (>300 lines total): +1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite +2. THEN: Append remaining content in 250-300 line chunks using file append operations +3. REPEAT: Continue appending until complete + +### For EDITING EXISTING FILES: +1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed +2. NEVER rewrite entire files - use incremental modifications +3. Split large refactors into multiple small, focused edits + +### For LARGE CODE GENERATION: +1. Generate in logical sections (imports, types, functions separately) +2. Write each section as a separate operation +3. Use append operations for subsequent sections + +## EXAMPLES OF CORRECT BEHAVIOR + +✅ CORRECT: Writing a 600-line file +- Operation 1: Write lines 1-300 (initial file creation) +- Operation 2: Append lines 301-600 + +✅ CORRECT: Editing multiple functions +- Operation 1: Edit function A +- Operation 2: Edit function B +- Operation 3: Edit function C + +❌ WRONG: Writing 500 lines in single operation → TIMEOUT +❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT +❌ WRONG: Generating massive code blocks without chunking → TIMEOUT + +## WHY THIS MATTERS +- Server has 2-3 minute timeout for operations +- Large writes exceed timeout and FAIL completely +- Chunked writes are FASTER and more RELIABLE +- Failed writes waste time and require retry + +REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.` +) + +// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. +type KiroExecutor struct { + cfg *config.Config + refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions +} + +// NewKiroExecutor creates a new Kiro executor instance. +func NewKiroExecutor(cfg *config.Config) *KiroExecutor { + return &KiroExecutor{cfg: cfg} +} + +// Identifier returns the unique identifier for this executor. +func (e *KiroExecutor) Identifier() string { return "kiro" } + +// PrepareRequest prepares the HTTP request before execution. +func (e *KiroExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } + + +// Execute sends the request to Kiro API and returns the response. +// Supports automatic token refresh on 401/403 errors. +func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + accessToken, profileArn := kiroCredentials(auth) + if accessToken == "" { + return resp, fmt.Errorf("kiro: access token not found in auth") + } + if profileArn == "" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + // Check if token is expired before making request + if e.isTokenExpired(accessToken) { + log.Infof("kiro: access token expired, attempting refresh before request") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before request") + } + } + + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + + // Check if this is an agentic model variant + isAgentic := strings.HasSuffix(req.Model, "-agentic") + + // Check if this is a chat-only model variant (no tool calling) + isChatOnly := strings.HasSuffix(req.Model, "-chat") + + // Determine initial origin based on model type + // Opus models use AI_EDITOR (Kiro IDE quota), others start with CLI (Amazon Q quota) + var currentOrigin string + if strings.Contains(strings.ToLower(req.Model), "opus") { + currentOrigin = "AI_EDITOR" + } else { + currentOrigin = "CLI" + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + // Execute with retry on 401/403 and 429 (quota exhausted) + resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + return resp, err +} + +// executeWithRetry performs the actual HTTP request with automatic retry on auth errors. +// Supports automatic fallback from CLI (Amazon Q) quota to AI_EDITOR (Kiro IDE) quota on 429. +func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) { + var resp cliproxyexecutor.Response + maxRetries := 2 // Allow retries for token refresh + origin fallback + + for attempt := 0; attempt <= maxRetries; attempt++ { + url := kiroEndpoint + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return resp, err + } + + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("x-amz-target", kiroTargetChat) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("Accept", kiroAcceptStream) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + // Handle 429 errors (quota exhausted) with origin fallback + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // If currently using CLI quota and it's exhausted, switch to AI_EDITOR (Kiro IDE) quota + if currentOrigin == "CLI" { + log.Warnf("kiro: Amazon Q (CLI) quota exhausted (429), switching to Kiro (AI_EDITOR) fallback") + currentOrigin = "AI_EDITOR" + + // Rebuild payload with new origin + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + // Retry with new origin + continue + } + + // Already on AI_EDITOR or other origin, return the error + log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 401/403 errors with token refresh and retry + if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + log.Warnf("kiro: received %d error, attempting token refresh and retry (attempt %d/%d)", httpResp.StatusCode, attempt+1, maxRetries+1) + + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed successfully, retrying request") + continue + } + } + + log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + return resp, err + } + + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + content, toolUses, usageInfo, err := e.parseEventStream(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + + // Fallback for usage if missing from upstream + if usageInfo.TotalTokens == 0 { + if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { + usageInfo.InputTokens = inp + } + } + if len(content) > 0 { + usageInfo.OutputTokens = int64(len(content) / 4) + if usageInfo.OutputTokens == 0 { + usageInfo.OutputTokens = 1 + } + } + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens + } + + appendAPIResponseChunk(ctx, e.cfg, []byte(content)) + reporter.publish(ctx, usageInfo) + + // Build response in Claude format for Kiro translator + kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo) + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil + } + + return resp, fmt.Errorf("kiro: max retries exceeded") +} + +// ExecuteStream handles streaming requests to Kiro API. +// Supports automatic token refresh on 401/403 errors and quota fallback on 429. +func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + accessToken, profileArn := kiroCredentials(auth) + if accessToken == "" { + return nil, fmt.Errorf("kiro: access token not found in auth") + } + if profileArn == "" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + // Check if token is expired before making request + if e.isTokenExpired(accessToken) { + log.Infof("kiro: access token expired, attempting refresh before stream request") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before stream request") + } + } + + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + + // Check if this is an agentic model variant + isAgentic := strings.HasSuffix(req.Model, "-agentic") + + // Check if this is a chat-only model variant (no tool calling) + isChatOnly := strings.HasSuffix(req.Model, "-chat") + + // Determine initial origin based on model type + // Opus models use AI_EDITOR (Kiro IDE quota), others start with CLI (Amazon Q quota) + var currentOrigin string + if strings.Contains(strings.ToLower(req.Model), "opus") { + currentOrigin = "AI_EDITOR" + } else { + currentOrigin = "CLI" + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + // Execute stream with retry on 401/403 and 429 (quota exhausted) + return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) +} + +// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. +// Supports automatic fallback from CLI (Amazon Q) quota to AI_EDITOR (Kiro IDE) quota on 429. +func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) { + maxRetries := 2 // Allow retries for token refresh + origin fallback + + for attempt := 0; attempt <= maxRetries; attempt++ { + url := kiroEndpoint + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("x-amz-target", kiroTargetChat) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("Accept", kiroAcceptStream) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + // Handle 429 errors (quota exhausted) with origin fallback + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // If currently using CLI quota and it's exhausted, switch to AI_EDITOR (Kiro IDE) quota + if currentOrigin == "CLI" { + log.Warnf("kiro: stream Amazon Q (CLI) quota exhausted (429), switching to Kiro (AI_EDITOR) fallback") + currentOrigin = "AI_EDITOR" + + // Rebuild payload with new origin + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + // Retry with new origin + continue + } + + // Already on AI_EDITOR or other origin, return the error + log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(respBody)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 401/403 errors with token refresh and retry + if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + log.Warnf("kiro: stream received %d error, attempting token refresh and retry (attempt %d/%d)", httpResp.StatusCode, attempt+1, maxRetries+1) + + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed successfully, retrying stream request") + continue + } + } + + log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(respBody)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + + out := make(chan cliproxyexecutor.StreamChunk) + + go func(resp *http.Response) { + defer close(out) + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter) + }(httpResp) + + return out, nil + } + + return nil, fmt.Errorf("kiro: max retries exceeded for stream") +} + + +// kiroCredentials extracts access token and profile ARN from auth. +func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { + if auth == nil { + return "", "" + } + + // Try Metadata first (wrapper format) + if auth.Metadata != nil { + if token, ok := auth.Metadata["access_token"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profile_arn"].(string); ok { + profileArn = arn + } + } + + // Try Attributes + if accessToken == "" && auth.Attributes != nil { + accessToken = auth.Attributes["access_token"] + profileArn = auth.Attributes["profile_arn"] + } + + // Try direct fields from flat JSON format (new AWS Builder ID format) + if accessToken == "" && auth.Metadata != nil { + if token, ok := auth.Metadata["accessToken"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profileArn"].(string); ok { + profileArn = arn + } + } + + return accessToken, profileArn +} + +// mapModelToKiro maps external model names to Kiro model IDs. +// Supports both Kiro and Amazon Q prefixes since they use the same API. +// Agentic variants (-agentic suffix) map to the same backend model IDs. +func (e *KiroExecutor) mapModelToKiro(model string) string { + modelMap := map[string]string{ + // Proxy format (kiro- prefix) + "kiro-auto": "auto", + "kiro-claude-opus-4.5": "claude-opus-4.5", + "kiro-claude-sonnet-4.5": "claude-sonnet-4.5", + "kiro-claude-sonnet-4": "claude-sonnet-4", + "kiro-claude-haiku-4.5": "claude-haiku-4.5", + // Amazon Q format (amazonq- prefix) - same API as Kiro + "amazonq-auto": "auto", + "amazonq-claude-opus-4.5": "claude-opus-4.5", + "amazonq-claude-sonnet-4.5": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4": "claude-sonnet-4", + "amazonq-claude-haiku-4.5": "claude-haiku-4.5", + // Native Kiro format (no prefix) - used by Kiro IDE directly + "claude-opus-4.5": "claude-opus-4.5", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-haiku-4.5": "claude-haiku-4.5", + "auto": "auto", + // Chat variant (no tool calling support) + "kiro-claude-opus-4.5-chat": "claude-opus-4.5", + // Agentic variants (same backend model IDs, but with special system prompt) + "kiro-claude-opus-4.5-agentic": "claude-opus-4.5", + "kiro-claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", + "kiro-claude-haiku-4.5-agentic": "claude-haiku-4.5", + "amazonq-claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + } + if kiroID, ok := modelMap[model]; ok { + return kiroID + } + log.Debugf("kiro: unknown model '%s', falling back to 'auto'", model) + return "auto" +} + +// Kiro API request structs - field order determines JSON key order + +type kiroPayload struct { + ConversationState kiroConversationState `json:"conversationState"` + ProfileArn string `json:"profileArn,omitempty"` +} + +type kiroConversationState struct { + ConversationID string `json:"conversationId"` + History []kiroHistoryMessage `json:"history"` + CurrentMessage kiroCurrentMessage `json:"currentMessage"` + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" +} + +type kiroCurrentMessage struct { + UserInputMessage kiroUserInputMessage `json:"userInputMessage"` +} + +type kiroHistoryMessage struct { + UserInputMessage *kiroUserInputMessage `json:"userInputMessage,omitempty"` + AssistantResponseMessage *kiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` +} + +type kiroUserInputMessage struct { + Content string `json:"content"` + ModelID string `json:"modelId"` + Origin string `json:"origin"` + UserInputMessageContext *kiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` +} + +type kiroUserInputMessageContext struct { + ToolResults []kiroToolResult `json:"toolResults,omitempty"` + Tools []kiroToolWrapper `json:"tools,omitempty"` +} + +type kiroToolResult struct { + ToolUseID string `json:"toolUseId"` + Content []kiroTextContent `json:"content"` + Status string `json:"status"` +} + +type kiroTextContent struct { + Text string `json:"text"` +} + +type kiroToolWrapper struct { + ToolSpecification kiroToolSpecification `json:"toolSpecification"` +} + +type kiroToolSpecification struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema kiroInputSchema `json:"inputSchema"` +} + +type kiroInputSchema struct { + JSON interface{} `json:"json"` +} + +type kiroAssistantResponseMessage struct { + Content string `json:"content"` + ToolUses []kiroToolUse `json:"toolUses,omitempty"` +} + +type kiroToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + +// buildKiroPayload constructs the Kiro API request payload. +// Supports tool calling - tools are passed via userInputMessageContext. +// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +// isAgentic parameter enables chunked write optimization prompt for -agentic model variants. +// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). +func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { + messages := gjson.GetBytes(claudeBody, "messages") + + // For chat-only mode, don't include tools + var tools gjson.Result + if !isChatOnly { + tools = gjson.GetBytes(claudeBody, "tools") + } + + // Extract system prompt - can be string or array of content blocks + systemField := gjson.GetBytes(claudeBody, "system") + var systemPrompt string + if systemField.IsArray() { + // System is array of content blocks, extract text + var sb strings.Builder + for _, block := range systemField.Array() { + if block.Get("type").String() == "text" { + sb.WriteString(block.Get("text").String()) + } else if block.Type == gjson.String { + sb.WriteString(block.String()) + } + } + systemPrompt = sb.String() + } else { + systemPrompt = systemField.String() + } + + // Inject agentic optimization prompt for -agentic model variants + // This prevents AWS Kiro API timeouts during large file write operations + if isAgentic { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += kiroAgenticSystemPrompt + } + + // Convert Claude tools to Kiro format + var kiroTools []kiroToolWrapper + if tools.IsArray() { + for _, tool := range tools.Array() { + name := tool.Get("name").String() + description := tool.Get("description").String() + inputSchema := tool.Get("input_schema").Value() + + // Truncate long descriptions (Kiro API limit is in bytes) + // Truncate at valid UTF-8 boundary to avoid breaking multi-byte chars + if len(description) > kiroMaxToolDescLen { + // Find a valid UTF-8 boundary before the limit + truncLen := kiroMaxToolDescLen + for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { + truncLen-- + } + description = description[:truncLen] + "..." + } + + kiroTools = append(kiroTools, kiroToolWrapper{ + ToolSpecification: kiroToolSpecification{ + Name: name, + Description: description, + InputSchema: kiroInputSchema{JSON: inputSchema}, + }, + }) + } + } + + var history []kiroHistoryMessage + var currentUserMsg *kiroUserInputMessage + var currentToolResults []kiroToolResult + + messagesArray := messages.Array() + for i, msg := range messagesArray { + role := msg.Get("role").String() + isLastMessage := i == len(messagesArray)-1 + + if role == "user" { + userMsg, toolResults := e.buildUserMessageStruct(msg, modelID, origin) + if isLastMessage { + currentUserMsg = &userMsg + currentToolResults = toolResults + } else { + // For history messages, embed tool results in context + if len(toolResults) > 0 { + userMsg.UserInputMessageContext = &kiroUserInputMessageContext{ + ToolResults: toolResults, + } + } + history = append(history, kiroHistoryMessage{ + UserInputMessage: &userMsg, + }) + } + } else if role == "assistant" { + assistantMsg := e.buildAssistantMessageStruct(msg) + history = append(history, kiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } + } + + // Build content with system prompt + if currentUserMsg != nil { + var contentBuilder strings.Builder + + // Add system prompt if present + if systemPrompt != "" { + contentBuilder.WriteString("--- SYSTEM PROMPT ---\n") + contentBuilder.WriteString(systemPrompt) + contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n") + } + + // Add the actual user message + contentBuilder.WriteString(currentUserMsg.Content) + currentUserMsg.Content = contentBuilder.String() + + // Build userInputMessageContext with tools and tool results + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + currentUserMsg.UserInputMessageContext = &kiroUserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + } + + // Build payload using structs (preserves key order) + var currentMessage kiroCurrentMessage + if currentUserMsg != nil { + currentMessage = kiroCurrentMessage{UserInputMessage: *currentUserMsg} + } else { + // Fallback when no user messages - still include system prompt if present + fallbackContent := "" + if systemPrompt != "" { + fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n" + } + currentMessage = kiroCurrentMessage{UserInputMessage: kiroUserInputMessage{ + Content: fallbackContent, + ModelID: modelID, + Origin: origin, + }} + } + + payload := kiroPayload{ + ConversationState: kiroConversationState{ + ConversationID: uuid.New().String(), + History: history, + CurrentMessage: currentMessage, + ChatTriggerType: "MANUAL", // Required by Kiro API + }, + ProfileArn: profileArn, + } + + // Ensure history is not nil (empty array) + if payload.ConversationState.History == nil { + payload.ConversationState.History = []kiroHistoryMessage{} + } + + result, err := json.Marshal(payload) + if err != nil { + log.Debugf("kiro: failed to marshal payload: %v", err) + return nil + } + return result +} + +// buildUserMessageStruct builds a user message and extracts tool results +// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin string) (kiroUserInputMessage, []kiroToolResult) { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolResults []kiroToolResult + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "tool_result": + // Extract tool result for API + toolUseID := part.Get("tool_use_id").String() + isError := part.Get("is_error").Bool() + resultContent := part.Get("content") + + // Convert content to Kiro format: [{text: "..."}] + var textContents []kiroTextContent + if resultContent.IsArray() { + for _, item := range resultContent.Array() { + if item.Get("type").String() == "text" { + textContents = append(textContents, kiroTextContent{Text: item.Get("text").String()}) + } else if item.Type == gjson.String { + textContents = append(textContents, kiroTextContent{Text: item.String()}) + } + } + } else if resultContent.Type == gjson.String { + textContents = append(textContents, kiroTextContent{Text: resultContent.String()}) + } + + // If no content, add default message + if len(textContents) == 0 { + textContents = append(textContents, kiroTextContent{Text: "Tool use was cancelled by the user"}) + } + + status := "success" + if isError { + status = "error" + } + + toolResults = append(toolResults, kiroToolResult{ + ToolUseID: toolUseID, + Content: textContents, + Status: status, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + return kiroUserInputMessage{ + Content: contentBuilder.String(), + ModelID: modelID, + Origin: origin, + }, toolResults +} + +// buildAssistantMessageStruct builds an assistant message with tool uses +func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssistantResponseMessage { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolUses []kiroToolUse + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "tool_use": + // Extract tool use for API + toolUseID := part.Get("id").String() + toolName := part.Get("name").String() + toolInput := part.Get("input") + + // Convert input to map + var inputMap map[string]interface{} + if toolInput.IsObject() { + inputMap = make(map[string]interface{}) + toolInput.ForEach(func(key, value gjson.Result) bool { + inputMap[key.String()] = value.Value() + return true + }) + } + + toolUses = append(toolUses, kiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + return kiroAssistantResponseMessage{ + Content: contentBuilder.String(), + ToolUses: toolUses, + } +} + +// NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults + +// parseEventStream parses AWS Event Stream binary format. +// Extracts text content and tool uses from the response. +// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. +func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, error) { + var content strings.Builder + var toolUses []kiroToolUse + var usageInfo usage.Detail + reader := bufio.NewReader(body) + + // Tool use state tracking for input buffering and deduplication + processedIDs := make(map[string]bool) + var currentToolUse *toolUseState + + for { + prelude := make([]byte, 8) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + break + } + if err != nil { + return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read prelude: %w", err) + } + + totalLen := binary.BigEndian.Uint32(prelude[0:4]) + if totalLen < 8 { + return content.String(), toolUses, usageInfo, fmt.Errorf("invalid message length: %d", totalLen) + } + if totalLen > kiroMaxMessageSize { + return content.String(), toolUses, usageInfo, fmt.Errorf("message too large: %d bytes", totalLen) + } + headersLen := binary.BigEndian.Uint32(prelude[4:8]) + + remaining := make([]byte, totalLen-8) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err) + } + + // Extract event type from headers + eventType := e.extractEventType(remaining[:headersLen+4]) + + payloadStart := 4 + headersLen + payloadEnd := uint32(len(remaining)) - 4 + if payloadStart >= payloadEnd { + continue + } + + payload := remaining[payloadStart:payloadEnd] + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + log.Debugf("kiro: skipping malformed event: %v", err) + continue + } + + // Handle different event types + switch eventType { + case "assistantResponseEvent": + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if contentText, ok := assistantResp["content"].(string); ok { + content.WriteString(contentText) + } + // Extract tool uses from response + if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok { + for _, tuRaw := range toolUsesRaw { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUseID := getString(tu, "toolUseId") + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + toolUse := kiroToolUse{ + ToolUseID: toolUseID, + Name: getString(tu, "name"), + } + if input, ok := tu["input"].(map[string]interface{}); ok { + toolUse.Input = input + } + toolUses = append(toolUses, toolUse) + } + } + } + } + // Also try direct format + if contentText, ok := event["content"].(string); ok { + content.WriteString(contentText) + } + // Direct tool uses + if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok { + for _, tuRaw := range toolUsesRaw { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUseID := getString(tu, "toolUseId") + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + toolUse := kiroToolUse{ + ToolUseID: toolUseID, + Name: getString(tu, "name"), + } + if input, ok := tu["input"].(map[string]interface{}); ok { + toolUse.Input = input + } + toolUses = append(toolUses, toolUse) + } + } + } + + case "toolUseEvent": + // Handle dedicated tool use events with input buffering + completedToolUses, newState := e.processToolUseEvent(event, currentToolUse, processedIDs) + currentToolUse = newState + toolUses = append(toolUses, completedToolUses...) + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + } + + // Also check nested supplementaryWebLinksEvent + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + } + } + + // Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}]) + contentStr := content.String() + cleanedContent, embeddedToolUses := e.parseEmbeddedToolCalls(contentStr, processedIDs) + toolUses = append(toolUses, embeddedToolUses...) + + // Deduplicate all tool uses + toolUses = deduplicateToolUses(toolUses) + + return cleanedContent, toolUses, usageInfo, nil +} + +// extractEventType extracts the event type from AWS Event Stream headers +func (e *KiroExecutor) extractEventType(headerBytes []byte) string { + // Skip prelude CRC (4 bytes) + if len(headerBytes) < 4 { + return "" + } + headers := headerBytes[4:] + + offset := 0 + for offset < len(headers) { + if offset >= len(headers) { + break + } + nameLen := int(headers[offset]) + offset++ + if offset+nameLen > len(headers) { + break + } + name := string(headers[offset : offset+nameLen]) + offset += nameLen + + if offset >= len(headers) { + break + } + valueType := headers[offset] + offset++ + + if valueType == 7 { // String type + if offset+2 > len(headers) { + break + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + break + } + value := string(headers[offset : offset+valueLen]) + offset += valueLen + + if name == ":event-type" { + return value + } + } else { + // Skip other types + break + } + } + return "" +} + +// getString safely extracts a string from a map +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key].(string); ok { + return v + } + return "" +} + +// buildClaudeResponse constructs a Claude-compatible response. +// Supports tool_use blocks when tools are present in the response. +func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte { + var contentBlocks []map[string]interface{} + + // Add text content if present + if content != "" { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "text", + "text": content, + }) + } + + // Add tool_use blocks + for _, toolUse := range toolUses { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "tool_use", + "id": toolUse.ToolUseID, + "name": toolUse.Name, + "input": toolUse.Input, + }) + } + + // Ensure at least one content block (Claude API requires non-empty content) + if len(contentBlocks) == 0 { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + // Determine stop reason + stopReason := "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + + response := map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": "message", + "role": "assistant", + "model": model, + "content": contentBlocks, + "stop_reason": stopReason, + "usage": map[string]interface{}{ + "input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + result, _ := json.Marshal(response) + return result +} + +// NOTE: Tool uses are now extracted from API response, not parsed from text + + +// streamToChannel converts AWS Event Stream to channel-based streaming. +// Supports tool calling - emits tool_use content blocks when tools are used. +// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. +func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { + reader := bufio.NewReader(body) + var totalUsage usage.Detail + var hasToolUses bool // Track if any tool uses were emitted + + // Tool use state tracking for input buffering and deduplication + processedIDs := make(map[string]bool) + var currentToolUse *toolUseState + + // Translator param for maintaining tool call state across streaming events + // IMPORTANT: This must persist across all TranslateStream calls + var translatorParam any + + // Pre-calculate input tokens from request if possible + if enc, err := tokenizerForModel(model); err == nil { + // Try OpenAI format first, then fall back to raw byte count estimation + if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + // Fallback: estimate from raw request size (roughly 4 chars per token) + totalUsage.InputTokens = int64(len(originalReq) / 4) + if totalUsage.InputTokens == 0 && len(originalReq) > 0 { + totalUsage.InputTokens = 1 + } + } + log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) + } + + contentBlockIndex := -1 + messageStartSent := false + isTextBlockOpen := false + var outputLen int + + // Ensure usage is published even on early return + defer func() { + reporter.publish(ctx, totalUsage) + }() + + for { + select { + case <-ctx.Done(): + return + default: + } + + prelude := make([]byte, 8) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + break + } + if err != nil { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read prelude: %w", err)} + return + } + + totalLen := binary.BigEndian.Uint32(prelude[0:4]) + if totalLen < 8 { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("invalid message length: %d", totalLen)} + return + } + if totalLen > kiroMaxMessageSize { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("message too large: %d bytes", totalLen)} + return + } + headersLen := binary.BigEndian.Uint32(prelude[4:8]) + + remaining := make([]byte, totalLen-8) + _, err = io.ReadFull(reader, remaining) + if err != nil { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read message: %w", err)} + return + } + + eventType := e.extractEventType(remaining[:headersLen+4]) + + payloadStart := 4 + headersLen + payloadEnd := uint32(len(remaining)) - 4 + if payloadStart >= payloadEnd { + continue + } + + payload := remaining[payloadStart:payloadEnd] + appendAPIResponseChunk(ctx, e.cfg, payload) + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + continue + } + + // Send message_start on first event + if !messageStartSent { + msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + messageStartSent = true + } + + switch eventType { + case "assistantResponseEvent": + var contentDelta string + var toolUses []map[string]interface{} + + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if c, ok := assistantResp["content"].(string); ok { + contentDelta = c + } + // Extract tool uses from response + if tus, ok := assistantResp["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + } + if contentDelta == "" { + if c, ok := event["content"].(string); ok { + contentDelta = c + } + } + // Direct tool uses + if tus, ok := event["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + + // Handle text content + if contentDelta != "" { + outputLen += len(contentDelta) + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Handle tool uses in response (with deduplication) + for _, tu := range toolUses { + toolUseID := getString(tu, "toolUseId") + + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + hasToolUses = true + // Close text block if open before starting tool_use block + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + // Emit tool_use content block + contentBlockIndex++ + toolName := getString(tu, "name") + + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send input_json_delta with the tool input + if input, ok := tu["input"].(map[string]interface{}); ok { + inputJSON, err := json.Marshal(input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input: %v", err) + // Don't continue - still need to close the block + } else { + inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + // Close tool_use block (always close even if input marshal failed) + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "toolUseEvent": + // Handle dedicated tool use events with input buffering + completedToolUses, newState := e.processToolUseEvent(event, currentToolUse, processedIDs) + currentToolUse = newState + + // Emit completed tool uses + for _, tu := range completedToolUses { + hasToolUses = true + + // Close text block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + contentBlockIndex++ + + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + if tu.Input != nil { + inputJSON, err := json.Marshal(tu.Input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) + } else { + inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + + // Check nested usage event + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + } + + // Close content block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Fallback for output tokens if not received from upstream + if totalUsage.OutputTokens == 0 && outputLen > 0 { + totalUsage.OutputTokens = int64(outputLen / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + } + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + + // Determine stop reason based on whether tool uses were emitted + stopReason := "end_turn" + if hasToolUses { + stopReason = "tool_use" + } + + // Send message_delta and message_stop + msgStop := e.buildClaudeMessageStopEvent(stopReason, totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + // reporter.publish is called via defer +} + + +// Claude SSE event builders +func (e *KiroExecutor) buildClaudeMessageStartEvent(model string, inputTokens int64) []byte { + event := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": "message", + "role": "assistant", + "content": []interface{}{}, + "model": model, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0}, + }, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { + var contentBlock map[string]interface{} + if blockType == "tool_use" { + contentBlock = map[string]interface{}{ + "type": "tool_use", + "id": toolUseID, + "name": toolName, + "input": map[string]interface{}{}, + } + } else { + contentBlock = map[string]interface{}{ + "type": "text", + "text": "", + } + } + + event := map[string]interface{}{ + "type": "content_block_start", + "index": index, + "content_block": contentBlock, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeStreamEvent(contentDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": contentDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +// buildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming +func (e *KiroExecutor) buildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": partialJSON, + }, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { + event := map[string]interface{}{ + "type": "content_block_stop", + "index": index, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo usage.Detail) []byte { + // First message_delta + deltaEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + deltaResult, _ := json.Marshal(deltaEvent) + + // Then message_stop + stopEvent := map[string]interface{}{ + "type": "message_stop", + } + stopResult, _ := json.Marshal(stopEvent) + + return []byte("data: " + string(deltaResult) + "\n\ndata: " + string(stopResult)) +} + +// buildClaudeFinalEvent constructs the final Claude-style event. +func (e *KiroExecutor) buildClaudeFinalEvent() []byte { + event := map[string]interface{}{ + "type": "message_stop", + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +// CountTokens is not supported for the Kiro provider. +func (e *KiroExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for kiro"} +} + +// Refresh refreshes the Kiro OAuth token. +// Supports both AWS Builder ID (SSO OIDC) and Google OAuth (social login). +// Uses mutex to prevent race conditions when multiple concurrent requests try to refresh. +func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + // Serialize token refresh operations to prevent race conditions + e.refreshMu.Lock() + defer e.refreshMu.Unlock() + + log.Debugf("kiro executor: refresh called for auth %s", auth.ID) + if auth == nil { + return nil, fmt.Errorf("kiro executor: auth is nil") + } + + // Double-check: After acquiring lock, verify token still needs refresh + // Another goroutine may have already refreshed while we were waiting + // NOTE: This check has a design limitation - it reads from the auth object passed in, + // not from persistent storage. If another goroutine returns a new Auth object (via Clone), + // this check won't see those updates. The mutex still prevents truly concurrent refreshes, + // but queued goroutines may still attempt redundant refreshes. This is acceptable as + // the refresh operation is idempotent and the extra API calls are infrequent. + if auth.Metadata != nil { + if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { + if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { + // If token was refreshed within the last 30 seconds, skip refresh + if time.Since(refreshTime) < 30*time.Second { + log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") + return auth, nil + } + } + } + // Also check if expires_at is now in the future with sufficient buffer + if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { + if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { + // If token expires more than 2 minutes from now, it's still valid + if time.Until(expTime) > 2*time.Minute { + log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) + return auth, nil + } + } + } + } + + var refreshToken string + var clientID, clientSecret string + var authMethod string + + if auth.Metadata != nil { + if rt, ok := auth.Metadata["refresh_token"].(string); ok { + refreshToken = rt + } + if cid, ok := auth.Metadata["client_id"].(string); ok { + clientID = cid + } + if cs, ok := auth.Metadata["client_secret"].(string); ok { + clientSecret = cs + } + if am, ok := auth.Metadata["auth_method"].(string); ok { + authMethod = am + } + } + + if refreshToken == "" { + return nil, fmt.Errorf("kiro executor: refresh token not found") + } + + var tokenData *kiroauth.KiroTokenData + var err error + + // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint + if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + } else { + log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") + oauth := kiroauth.NewKiroOAuth(e.cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("kiro executor: token refresh failed: %w", err) + } + + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + + if updated.Metadata == nil { + updated.Metadata = make(map[string]any) + } + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) + if tokenData.ProfileArn != "" { + updated.Metadata["profile_arn"] = tokenData.ProfileArn + } + if tokenData.AuthMethod != "" { + updated.Metadata["auth_method"] = tokenData.AuthMethod + } + if tokenData.Provider != "" { + updated.Metadata["provider"] = tokenData.Provider + } + // Preserve client credentials for future refreshes (AWS Builder ID) + if tokenData.ClientID != "" { + updated.Metadata["client_id"] = tokenData.ClientID + } + if tokenData.ClientSecret != "" { + updated.Metadata["client_secret"] = tokenData.ClientSecret + } + + if updated.Attributes == nil { + updated.Attributes = make(map[string]string) + } + updated.Attributes["access_token"] = tokenData.AccessToken + if tokenData.ProfileArn != "" { + updated.Attributes["profile_arn"] = tokenData.ProfileArn + } + + // Set next refresh time to 30 minutes before expiry + if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { + updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + } + + log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) + return updated, nil +} + +// streamEventStream converts AWS Event Stream to SSE (legacy method for gin.Context). +// Note: For full tool calling support, use streamToChannel instead. +func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c *gin.Context, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) error { + reader := bufio.NewReader(body) + var totalUsage usage.Detail + + // Translator param for maintaining tool call state across streaming events + var translatorParam any + + // Pre-calculate input tokens from request if possible + if enc, err := tokenizerForModel(model); err == nil { + // Try OpenAI format first, then fall back to raw byte count estimation + if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + // Fallback: estimate from raw request size (roughly 4 chars per token) + totalUsage.InputTokens = int64(len(originalReq) / 4) + if totalUsage.InputTokens == 0 && len(originalReq) > 0 { + totalUsage.InputTokens = 1 + } + } + log.Debugf("kiro: streamEventStream pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) + } + + contentBlockIndex := -1 + messageStartSent := false + isBlockOpen := false + var outputLen int + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + prelude := make([]byte, 8) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read prelude: %w", err) + } + + totalLen := binary.BigEndian.Uint32(prelude[0:4]) + if totalLen < 8 { + return fmt.Errorf("invalid message length: %d", totalLen) + } + if totalLen > kiroMaxMessageSize { + return fmt.Errorf("message too large: %d bytes", totalLen) + } + headersLen := binary.BigEndian.Uint32(prelude[4:8]) + + remaining := make([]byte, totalLen-8) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return fmt.Errorf("failed to read message: %w", err) + } + + eventType := e.extractEventType(remaining[:headersLen+4]) + + payloadStart := 4 + headersLen + payloadEnd := uint32(len(remaining)) - 4 + if payloadStart >= payloadEnd { + continue + } + + payload := remaining[payloadStart:payloadEnd] + appendAPIResponseChunk(ctx, e.cfg, payload) + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + continue + } + + if !messageStartSent { + msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + messageStartSent = true + } + + switch eventType { + case "assistantResponseEvent": + var contentDelta string + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if ct, ok := assistantResp["content"].(string); ok { + contentDelta = ct + } + } + if contentDelta == "" { + if ct, ok := event["content"].(string); ok { + contentDelta = ct + } + } + + if contentDelta != "" { + outputLen += len(contentDelta) + // Start text content block if needed + if !isBlockOpen { + contentBlockIndex++ + isBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + } + + claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + } + + // Note: For full toolUseEvent support, use streamToChannel + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + } + + // Close content block if open + if isBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + } + + // Fallback for output tokens if not received from upstream + if totalUsage.OutputTokens == 0 && outputLen > 0 { + totalUsage.OutputTokens = int64(outputLen / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + } + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + + // Always use end_turn (no tool_use support) + msgStop := e.buildClaudeMessageStopEvent("end_turn", totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + + c.Writer.Write([]byte("data: [DONE]\n\n")) + c.Writer.Flush() + + reporter.publish(ctx, totalUsage) + return nil +} + +// isTokenExpired checks if a JWT access token has expired. +// Returns true if the token is expired or cannot be parsed. +func (e *KiroExecutor) isTokenExpired(accessToken string) bool { + if accessToken == "" { + return true + } + + // JWT tokens have 3 parts separated by dots + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + // Not a JWT token, assume not expired + return false + } + + // Decode the payload (second part) + // JWT uses base64url encoding without padding (RawURLEncoding) + payload := parts[1] + decoded, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + // Try with padding added as fallback + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + decoded, err = base64.URLEncoding.DecodeString(payload) + if err != nil { + log.Debugf("kiro: failed to decode JWT payload: %v", err) + return false + } + } + + var claims struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(decoded, &claims); err != nil { + log.Debugf("kiro: failed to parse JWT claims: %v", err) + return false + } + + if claims.Exp == 0 { + // No expiration claim, assume not expired + return false + } + + expTime := time.Unix(claims.Exp, 0) + now := time.Now() + + // Consider token expired if it expires within 1 minute (buffer for clock skew) + isExpired := now.After(expTime) || expTime.Sub(now) < time.Minute + if isExpired { + log.Debugf("kiro: token expired at %s (now: %s)", expTime.Format(time.RFC3339), now.Format(time.RFC3339)) + } + + return isExpired +} + +// ============================================================================ +// Tool Calling Support - Embedded tool call parsing and input buffering +// Based on amq2api and AIClient-2-API implementations +// ============================================================================ + +// toolUseState tracks the state of an in-progress tool use during streaming. +type toolUseState struct { + toolUseID string + name string + inputBuffer strings.Builder + isComplete bool +} + +// Pre-compiled regex patterns for performance (avoid recompilation on each call) +var ( + // embeddedToolCallPattern matches [Called tool_name with args: {...}] format + // This pattern is used by Kiro when it embeds tool calls in text content + embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+(\w+)\s+with\s+args:\s*`) + // whitespaceCollapsePattern collapses multiple whitespace characters into single space + whitespaceCollapsePattern = regexp.MustCompile(`\s+`) + // trailingCommaPattern matches trailing commas before closing braces/brackets + trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) + // unquotedKeyPattern matches unquoted JSON keys that need quoting + unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`) +) + +// parseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. +// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent. +// Returns the cleaned text (with tool calls removed) and extracted tool uses. +func (e *KiroExecutor) parseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []kiroToolUse) { + if !strings.Contains(text, "[Called") { + return text, nil + } + + var toolUses []kiroToolUse + cleanText := text + + // Find all [Called markers + matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return text, nil + } + + // Process matches in reverse order to maintain correct indices + for i := len(matches) - 1; i >= 0; i-- { + matchStart := matches[i][0] + toolNameStart := matches[i][2] + toolNameEnd := matches[i][3] + + if toolNameStart < 0 || toolNameEnd < 0 { + continue + } + + toolName := text[toolNameStart:toolNameEnd] + + // Find the JSON object start (after "with args:") + jsonStart := matches[i][1] + if jsonStart >= len(text) { + continue + } + + // Skip whitespace to find the opening brace + for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') { + jsonStart++ + } + + if jsonStart >= len(text) || text[jsonStart] != '{' { + continue + } + + // Find matching closing bracket + jsonEnd := findMatchingBracket(text, jsonStart) + if jsonEnd < 0 { + continue + } + + // Extract JSON and find the closing bracket of [Called ...] + jsonStr := text[jsonStart : jsonEnd+1] + + // Find the closing ] after the JSON + closingBracket := jsonEnd + 1 + for closingBracket < len(text) && text[closingBracket] != ']' { + closingBracket++ + } + if closingBracket >= len(text) { + continue + } + + // Extract and repair the full tool call text + fullMatch := text[matchStart : closingBracket+1] + + // Repair and parse JSON + repairedJSON := repairJSON(jsonStr) + var inputMap map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil { + log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr) + continue + } + + // Generate unique tool ID + toolUseID := "toolu_" + uuid.New().String()[:12] + + // Check for duplicates using name+input as key + dedupeKey := toolName + ":" + repairedJSON + if processedIDs != nil { + if processedIDs[dedupeKey] { + log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName) + // Still remove from text even if duplicate + cleanText = strings.Replace(cleanText, fullMatch, "", 1) + continue + } + processedIDs[dedupeKey] = true + } + + toolUses = append(toolUses, kiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + + log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID) + + // Remove from clean text + cleanText = strings.Replace(cleanText, fullMatch, "", 1) + } + + // Clean up extra whitespace + cleanText = strings.TrimSpace(cleanText) + cleanText = whitespaceCollapsePattern.ReplaceAllString(cleanText, " ") + + return cleanText, toolUses +} + +// findMatchingBracket finds the index of the closing brace/bracket that matches +// the opening one at startPos. Handles nested objects and strings correctly. +func findMatchingBracket(text string, startPos int) int { + if startPos >= len(text) { + return -1 + } + + openChar := text[startPos] + var closeChar byte + switch openChar { + case '{': + closeChar = '}' + case '[': + closeChar = ']' + default: + return -1 + } + + depth := 1 + inString := false + escapeNext := false + + for i := startPos + 1; i < len(text); i++ { + char := text[i] + + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' && inString { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if !inString { + if char == openChar { + depth++ + } else if char == closeChar { + depth-- + if depth == 0 { + return i + } + } + } + } + + return -1 +} + +// repairJSON attempts to fix common JSON issues that may occur in tool call arguments. +// Based on AIClient-2-API's JSON repair implementation. +// Uses pre-compiled regex patterns for performance. +func repairJSON(raw string) string { + // Remove trailing commas before closing braces/brackets + repaired := trailingCommaPattern.ReplaceAllString(raw, "$1") + // Fix unquoted keys (basic attempt - handles simple cases) + repaired = unquotedKeyPattern.ReplaceAllString(repaired, `$1"$2":`) + return repaired +} + +// processToolUseEvent handles a toolUseEvent from the Kiro stream. +// It accumulates input fragments and emits tool_use blocks when complete. +// Returns events to emit and updated state. +func (e *KiroExecutor) processToolUseEvent(event map[string]interface{}, currentToolUse *toolUseState, processedIDs map[string]bool) ([]kiroToolUse, *toolUseState) { + var toolUses []kiroToolUse + + // Extract from nested toolUseEvent or direct format + tu := event + if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok { + tu = nested + } + + toolUseID := getString(tu, "toolUseId") + toolName := getString(tu, "name") + isStop := false + if stop, ok := tu["stop"].(bool); ok { + isStop = stop + } + + // Get input - can be string (fragment) or object (complete) + var inputFragment string + var inputMap map[string]interface{} + + if inputRaw, ok := tu["input"]; ok { + switch v := inputRaw.(type) { + case string: + inputFragment = v + case map[string]interface{}: + inputMap = v + } + } + + // New tool use starting + if toolUseID != "" && toolName != "" { + if currentToolUse != nil && currentToolUse.toolUseID != toolUseID { + // New tool use arrived while another is in progress (interleaved events) + // This is unusual - log warning and complete the previous one + log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous", + toolUseID, currentToolUse.toolUseID) + // Emit incomplete previous tool use + if !processedIDs[currentToolUse.toolUseID] { + incomplete := kiroToolUse{ + ToolUseID: currentToolUse.toolUseID, + Name: currentToolUse.name, + } + if currentToolUse.inputBuffer.Len() > 0 { + var input map[string]interface{} + if err := json.Unmarshal([]byte(currentToolUse.inputBuffer.String()), &input); err == nil { + incomplete.Input = input + } + } + toolUses = append(toolUses, incomplete) + processedIDs[currentToolUse.toolUseID] = true + } + currentToolUse = nil + } + + if currentToolUse == nil { + // Check for duplicate + if processedIDs != nil && processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID) + return nil, nil + } + + currentToolUse = &toolUseState{ + toolUseID: toolUseID, + name: toolName, + } + log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID) + } + } + + // Accumulate input fragments + if currentToolUse != nil && inputFragment != "" { + currentToolUse.inputBuffer.WriteString(inputFragment) + log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.inputBuffer.Len()) + } + + // If complete input object provided directly + if currentToolUse != nil && inputMap != nil { + inputBytes, _ := json.Marshal(inputMap) + currentToolUse.inputBuffer.Reset() + currentToolUse.inputBuffer.Write(inputBytes) + } + + // Tool use complete + if isStop && currentToolUse != nil { + fullInput := currentToolUse.inputBuffer.String() + + // Repair and parse the accumulated JSON + repairedJSON := repairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput) + // Use empty input as fallback + finalInput = make(map[string]interface{}) + } + + toolUse := kiroToolUse{ + ToolUseID: currentToolUse.toolUseID, + Name: currentToolUse.name, + Input: finalInput, + } + toolUses = append(toolUses, toolUse) + + // Mark as processed + if processedIDs != nil { + processedIDs[currentToolUse.toolUseID] = true + } + + log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) + return toolUses, nil // Reset state + } + + return toolUses, currentToolUse +} + +// deduplicateToolUses removes duplicate tool uses based on toolUseId. +func deduplicateToolUses(toolUses []kiroToolUse) []kiroToolUse { + seen := make(map[string]bool) + var unique []kiroToolUse + + for _, tu := range toolUses { + if !seen[tu.ToolUseID] { + seen[tu.ToolUseID] = true + unique = append(unique, tu) + } else { + log.Debugf("kiro: removing duplicate tool use: %s", tu.ToolUseID) + } + } + + return unique +} diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index ab0f626a..8998eb23 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -14,11 +15,19 @@ import ( "golang.org/x/net/proxy" ) +// httpClientCache caches HTTP clients by proxy URL to enable connection reuse +var ( + httpClientCache = make(map[string]*http.Client) + httpClientCacheMutex sync.RWMutex +) + // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // 1. Use auth.ProxyURL if configured (highest priority) // 2. Use cfg.ProxyURL if auth proxy is not configured // 3. Use RoundTripper from context if neither are configured // +// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse. +// // Parameters: // - ctx: The context containing optional RoundTripper // - cfg: The application configuration @@ -28,11 +37,6 @@ import ( // Returns: // - *http.Client: An HTTP client with configured proxy or transport func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { - httpClient := &http.Client{} - if timeout > 0 { - httpClient.Timeout = timeout - } - // Priority 1: Use auth.ProxyURL if configured var proxyURL string if auth != nil { @@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip proxyURL = strings.TrimSpace(cfg.ProxyURL) } + // Build cache key from proxy URL (empty string for no proxy) + cacheKey := proxyURL + + // Check cache first + httpClientCacheMutex.RLock() + if cachedClient, ok := httpClientCache[cacheKey]; ok { + httpClientCacheMutex.RUnlock() + // Return a wrapper with the requested timeout but shared transport + if timeout > 0 { + return &http.Client{ + Transport: cachedClient.Transport, + Timeout: timeout, + } + } + return cachedClient + } + httpClientCacheMutex.RUnlock() + + // Create new client + httpClient := &http.Client{} + if timeout > 0 { + httpClient.Timeout = timeout + } + // If we have a proxy URL configured, set up the transport if proxyURL != "" { transport := buildProxyTransport(proxyURL) if transport != nil { httpClient.Transport = transport + // Cache the client + httpClientCacheMutex.Lock() + httpClientCache[cacheKey] = httpClient + httpClientCacheMutex.Unlock() return httpClient } // If proxy setup failed, log and fall through to context RoundTripper @@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip httpClient.Transport = rt } + // Cache the client for no-proxy case + if proxyURL == "" { + httpClientCacheMutex.Lock() + httpClientCache[cacheKey] = httpClient + httpClientCacheMutex.Unlock() + } + return httpClient } diff --git a/internal/translator/claude/gemini/claude_gemini_response.go b/internal/translator/claude/gemini/claude_gemini_response.go index 0c90398e..72e1820c 100644 --- a/internal/translator/claude/gemini/claude_gemini_response.go +++ b/internal/translator/claude/gemini/claude_gemini_response.go @@ -331,8 +331,9 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, streamingEvents := make([][]byte, 0) scanner := bufio.NewScanner(bytes.NewReader(rawJSON)) - buffer := make([]byte, 20_971_520) - scanner.Buffer(buffer, 20_971_520) + // Use a smaller initial buffer (64KB) that can grow up to 20MB if needed + // This prevents allocating 20MB for every request regardless of size + scanner.Buffer(make([]byte, 64*1024), 20_971_520) for scanner.Scan() { line := scanner.Bytes() // log.Debug(string(line)) diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go index f8fd4018..5b0238bf 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go @@ -50,6 +50,10 @@ type ToolCallAccumulator struct { // Returns: // - []string: A slice of strings, each containing an OpenAI-compatible JSON response func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { + var localParam any + if param == nil { + param = &localParam + } if *param == nil { *param = &ConvertAnthropicResponseToOpenAIParams{ CreatedAt: 0, diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 733668f3..e7f6275f 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -60,12 +60,12 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Track whether tools are being used in this response chunk usedTool := false - output := "" + var sb strings.Builder // Initialize the streaming session with a message_start event // This is only sent for the very first response chunk to establish the streaming session if !(*param).(*Params).HasFirstResponse { - output = "event: message_start\n" + sb.WriteString("event: message_start\n") // Create the initial message structure with default values according to Claude Code API specification // This follows the Claude Code API specification for streaming message initialization @@ -78,7 +78,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String()) } - output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)) (*param).(*Params).HasFirstResponse = true } @@ -101,62 +101,52 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if partResult.Get("thought").Bool() { // Continue existing thinking block if already in thinking state if (*param).(*Params).ResponseType == 2 { - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) } else { // Transition from another state to thinking // First, close any existing content block if (*param).(*Params).ResponseType != 0 { - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ } // Start a new thinking content block - output = output + "event: content_block_start\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_start\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) (*param).(*Params).ResponseType = 2 // Set state to thinking } } else { // Process regular text content (user-visible output) // Continue existing text block if already in content state if (*param).(*Params).ResponseType == 1 { - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) } else { // Transition from another state to text content // First, close any existing content block if (*param).(*Params).ResponseType != 0 { - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ } // Start a new text content block - output = output + "event: content_block_start\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_start\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) (*param).(*Params).ResponseType = 1 // Set state to content } } @@ -169,42 +159,35 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Handle state transitions when switching to function calls // Close any existing function call block first if (*param).(*Params).ResponseType == 3 { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseType = 0 } - // Special handling for thinking state transition - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - // Close any other existing content block if (*param).(*Params).ResponseType != 0 { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ } // Start a new tool use content block // This creates the structure for a function call in Claude Code format - output = output + "event: content_block_start\n" + sb.WriteString("event: content_block_start\n") // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) data, _ = sjson.Set(data, "content_block.name", fcName) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_delta\n") data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) } (*param).(*Params).ResponseType = 3 } @@ -216,13 +199,13 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { // Close the final content block - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") // Send the final message delta with usage information and stop reason - output = output + "event: message_delta\n" - output = output + `data: ` + sb.WriteString("event: message_delta\n") + sb.WriteString(`data: `) // Create the message delta template with appropriate stop reason template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` @@ -236,11 +219,11 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) - output = output + template + "\n\n\n" + sb.WriteString(template + "\n\n\n") } } - return []string{output} + return []string{sb.String()} } // ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. diff --git a/internal/translator/init.go b/internal/translator/init.go index 084ea7ac..d19d9b34 100644 --- a/internal/translator/init.go +++ b/internal/translator/init.go @@ -33,4 +33,7 @@ import ( _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses" + + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions" ) diff --git a/internal/translator/kiro/claude/init.go b/internal/translator/kiro/claude/init.go new file mode 100644 index 00000000..9e3a2ba3 --- /dev/null +++ b/internal/translator/kiro/claude/init.go @@ -0,0 +1,19 @@ +package claude + +import ( + . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" +) + +func init() { + translator.Register( + Claude, + Kiro, + ConvertClaudeRequestToKiro, + interfaces.TranslateResponse{ + Stream: ConvertKiroResponseToClaude, + NonStream: ConvertKiroResponseToClaudeNonStream, + }, + ) +} diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go new file mode 100644 index 00000000..9922860e --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude.go @@ -0,0 +1,24 @@ +// Package claude provides translation between Kiro and Claude formats. +// Since Kiro uses Claude-compatible format internally, translations are mostly pass-through. +package claude + +import ( + "bytes" + "context" +) + +// ConvertClaudeRequestToKiro converts Claude request to Kiro format. +// Since Kiro uses Claude format internally, this is mostly a pass-through. +func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + return bytes.Clone(inputRawJSON) +} + +// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format. +func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + return []string{string(rawResponse)} +} + +// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format. +func ConvertKiroResponseToClaudeNonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + return string(rawResponse) +} diff --git a/internal/translator/kiro/openai/chat-completions/init.go b/internal/translator/kiro/openai/chat-completions/init.go new file mode 100644 index 00000000..2a99d0e0 --- /dev/null +++ b/internal/translator/kiro/openai/chat-completions/init.go @@ -0,0 +1,19 @@ +package chat_completions + +import ( + . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" +) + +func init() { + translator.Register( + OpenAI, + Kiro, + ConvertOpenAIRequestToKiro, + interfaces.TranslateResponse{ + Stream: ConvertKiroResponseToOpenAI, + NonStream: ConvertKiroResponseToOpenAINonStream, + }, + ) +} diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go new file mode 100644 index 00000000..fc850d96 --- /dev/null +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -0,0 +1,258 @@ +// Package chat_completions provides request translation from OpenAI to Kiro format. +package chat_completions + +import ( + "bytes" + "encoding/json" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format. +// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format. +// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result. +func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + rawJSON := bytes.Clone(inputRawJSON) + root := gjson.ParseBytes(rawJSON) + + // Build Claude-compatible request + out := `{"model":"","max_tokens":32000,"messages":[]}` + + // Set model + out, _ = sjson.Set(out, "model", modelName) + + // Copy max_tokens if present + if v := root.Get("max_tokens"); v.Exists() { + out, _ = sjson.Set(out, "max_tokens", v.Int()) + } + + // Copy temperature if present + if v := root.Get("temperature"); v.Exists() { + out, _ = sjson.Set(out, "temperature", v.Float()) + } + + // Copy top_p if present + if v := root.Get("top_p"); v.Exists() { + out, _ = sjson.Set(out, "top_p", v.Float()) + } + + // Convert OpenAI tools to Claude tools format + if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { + claudeTools := make([]interface{}, 0) + for _, tool := range tools.Array() { + if tool.Get("type").String() == "function" { + fn := tool.Get("function") + claudeTool := map[string]interface{}{ + "name": fn.Get("name").String(), + "description": fn.Get("description").String(), + } + // Convert parameters to input_schema + if params := fn.Get("parameters"); params.Exists() { + claudeTool["input_schema"] = params.Value() + } else { + claudeTool["input_schema"] = map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + } + claudeTools = append(claudeTools, claudeTool) + } + } + if len(claudeTools) > 0 { + out, _ = sjson.Set(out, "tools", claudeTools) + } + } + + // Process messages + messages := root.Get("messages") + if messages.Exists() && messages.IsArray() { + claudeMessages := make([]interface{}, 0) + var systemPrompt string + + // Track pending tool results to merge with next user message + var pendingToolResults []map[string]interface{} + + for _, msg := range messages.Array() { + role := msg.Get("role").String() + content := msg.Get("content") + + if role == "system" { + // Extract system message + if content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + systemPrompt += part.Get("text").String() + "\n" + } + } + } else { + systemPrompt = content.String() + } + continue + } + + if role == "tool" { + // OpenAI tool message -> Claude tool_result content block + toolCallID := msg.Get("tool_call_id").String() + toolContent := content.String() + + toolResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": toolCallID, + } + + // Handle content - can be string or structured + if content.IsArray() { + contentParts := make([]interface{}, 0) + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": part.Get("text").String(), + }) + } + } + toolResult["content"] = contentParts + } else { + toolResult["content"] = toolContent + } + + pendingToolResults = append(pendingToolResults, toolResult) + continue + } + + claudeMsg := map[string]interface{}{ + "role": role, + } + + // Handle assistant messages with tool_calls + if role == "assistant" && msg.Get("tool_calls").Exists() { + contentParts := make([]interface{}, 0) + + // Add text content if present + if content.Exists() && content.String() != "" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": content.String(), + }) + } + + // Convert tool_calls to tool_use blocks + for _, toolCall := range msg.Get("tool_calls").Array() { + toolUseID := toolCall.Get("id").String() + fnName := toolCall.Get("function.name").String() + fnArgs := toolCall.Get("function.arguments").String() + + // Parse arguments JSON + var argsMap map[string]interface{} + if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil { + argsMap = map[string]interface{}{"raw": fnArgs} + } + + contentParts = append(contentParts, map[string]interface{}{ + "type": "tool_use", + "id": toolUseID, + "name": fnName, + "input": argsMap, + }) + } + + claudeMsg["content"] = contentParts + claudeMessages = append(claudeMessages, claudeMsg) + continue + } + + // Handle user messages - may need to include pending tool results + if role == "user" && len(pendingToolResults) > 0 { + contentParts := make([]interface{}, 0) + + // Add pending tool results first + for _, tr := range pendingToolResults { + contentParts = append(contentParts, tr) + } + pendingToolResults = nil + + // Add user content + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + if partType == "text" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": part.Get("text").String(), + }) + } else if partType == "image_url" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": part.Get("image_url.url").String(), + }, + }) + } + } + } else if content.String() != "" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": content.String(), + }) + } + + claudeMsg["content"] = contentParts + claudeMessages = append(claudeMessages, claudeMsg) + continue + } + + // Handle regular content + if content.IsArray() { + contentParts := make([]interface{}, 0) + for _, part := range content.Array() { + partType := part.Get("type").String() + if partType == "text" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": part.Get("text").String(), + }) + } else if partType == "image_url" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": part.Get("image_url.url").String(), + }, + }) + } + } + claudeMsg["content"] = contentParts + } else { + claudeMsg["content"] = content.String() + } + + claudeMessages = append(claudeMessages, claudeMsg) + } + + // If there are pending tool results without a following user message, + // create a user message with just the tool results + if len(pendingToolResults) > 0 { + contentParts := make([]interface{}, 0) + for _, tr := range pendingToolResults { + contentParts = append(contentParts, tr) + } + claudeMessages = append(claudeMessages, map[string]interface{}{ + "role": "user", + "content": contentParts, + }) + } + + out, _ = sjson.Set(out, "messages", claudeMessages) + + if systemPrompt != "" { + out, _ = sjson.Set(out, "system", systemPrompt) + } + } + + // Set stream + out, _ = sjson.Set(out, "stream", stream) + + return []byte(out) +} diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go new file mode 100644 index 00000000..6a0ad250 --- /dev/null +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go @@ -0,0 +1,316 @@ +// Package chat_completions provides response translation from Kiro to OpenAI format. +package chat_completions + +import ( + "context" + "encoding/json" + "time" + + "github.com/google/uuid" + "github.com/tidwall/gjson" +) + +// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format. +// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta, +// content_block_stop, message_delta, and message_stop. +func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + root := gjson.ParseBytes(rawResponse) + var results []string + + eventType := root.Get("type").String() + + switch eventType { + case "message_start": + // Initial message event - could emit initial chunk if needed + return results + + case "content_block_start": + // Start of a content block (text or tool_use) + blockType := root.Get("content_block.type").String() + index := int(root.Get("index").Int()) + + if blockType == "tool_use" { + // Start of tool_use block + toolUseID := root.Get("content_block.id").String() + toolName := root.Get("content_block.name").String() + + toolCall := map[string]interface{}{ + "index": index, + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": "", + }, + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + return results + + case "content_block_delta": + index := int(root.Get("index").Int()) + deltaType := root.Get("delta.type").String() + + if deltaType == "text_delta" { + // Text content delta + contentDelta := root.Get("delta.text").String() + if contentDelta != "" { + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "content": contentDelta, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + } else if deltaType == "input_json_delta" { + // Tool input delta (streaming arguments) + partialJSON := root.Get("delta.partial_json").String() + if partialJSON != "" { + toolCall := map[string]interface{}{ + "index": index, + "function": map[string]interface{}{ + "arguments": partialJSON, + }, + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + } + return results + + case "content_block_stop": + // End of content block - no output needed for OpenAI format + return results + + case "message_delta": + // Final message delta with stop_reason + stopReason := root.Get("delta.stop_reason").String() + if stopReason != "" { + finishReason := "stop" + if stopReason == "tool_use" { + finishReason = "tool_calls" + } else if stopReason == "end_turn" { + finishReason = "stop" + } else if stopReason == "max_tokens" { + finishReason = "length" + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{}, + "finish_reason": finishReason, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + return results + + case "message_stop": + // End of message - could emit [DONE] marker + return results + } + + // Fallback: handle raw content for backward compatibility + var contentDelta string + if delta := root.Get("delta.text"); delta.Exists() { + contentDelta = delta.String() + } else if content := root.Get("content"); content.Exists() && root.Get("type").String() == "" { + contentDelta = content.String() + } + + if contentDelta != "" { + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "content": contentDelta, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + + // Handle tool_use content blocks (Claude format) - fallback + toolUses := root.Get("delta.tool_use") + if !toolUses.Exists() { + toolUses = root.Get("tool_use") + } + if toolUses.Exists() && toolUses.IsObject() { + inputJSON := toolUses.Get("input").String() + if inputJSON == "" { + if inputObj := toolUses.Get("input"); inputObj.Exists() { + inputBytes, _ := json.Marshal(inputObj.Value()) + inputJSON = string(inputBytes) + } + } + + toolCall := map[string]interface{}{ + "index": 0, + "id": toolUses.Get("id").String(), + "type": "function", + "function": map[string]interface{}{ + "name": toolUses.Get("name").String(), + "arguments": inputJSON, + }, + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + + return results +} + +// ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format. +func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + root := gjson.ParseBytes(rawResponse) + + var content string + var toolCalls []map[string]interface{} + + contentArray := root.Get("content") + if contentArray.IsArray() { + for _, item := range contentArray.Array() { + itemType := item.Get("type").String() + if itemType == "text" { + content += item.Get("text").String() + } else if itemType == "tool_use" { + // Convert Claude tool_use to OpenAI tool_calls format + inputJSON := item.Get("input").String() + if inputJSON == "" { + // If input is an object, marshal it + if inputObj := item.Get("input"); inputObj.Exists() { + inputBytes, _ := json.Marshal(inputObj.Value()) + inputJSON = string(inputBytes) + } + } + toolCall := map[string]interface{}{ + "id": item.Get("id").String(), + "type": "function", + "function": map[string]interface{}{ + "name": item.Get("name").String(), + "arguments": inputJSON, + }, + } + toolCalls = append(toolCalls, toolCall) + } + } + } else { + content = root.Get("content").String() + } + + inputTokens := root.Get("usage.input_tokens").Int() + outputTokens := root.Get("usage.output_tokens").Int() + + message := map[string]interface{}{ + "role": "assistant", + "content": content, + } + + // Add tool_calls if present + if len(toolCalls) > 0 { + message["tool_calls"] = toolCalls + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "message": message, + "finish_reason": finishReason, + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": inputTokens, + "completion_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + }, + } + + result, _ := json.Marshal(response) + return string(result) +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 1f4f9043..da152141 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -21,6 +21,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "gopkg.in/yaml.v3" @@ -176,6 +177,9 @@ func (w *Watcher) Start(ctx context.Context) error { } log.Debugf("watching auth directory: %s", w.authDir) + // Watch Kiro IDE token file directory for automatic token updates + w.watchKiroIDETokenFile() + // Start the event processing goroutine go w.processEvents(ctx) @@ -184,6 +188,31 @@ func (w *Watcher) Start(ctx context.Context) error { return nil } +// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher. +// This enables automatic detection of token updates from Kiro IDE. +func (w *Watcher) watchKiroIDETokenFile() { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err) + return + } + + // Kiro IDE stores tokens in ~/.aws/sso/cache/ + kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) { + log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir) + return + } + + if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil { + log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd) + return + } + log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir) +} + // Stop stops the file watcher func (w *Watcher) Stop() error { w.stopDispatch() @@ -744,10 +773,20 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0 authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0 - if !isConfigEvent && !isAuthJSON { + + // Check for Kiro IDE token file changes + isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0 + + if !isConfigEvent && !isAuthJSON && !isKiroIDEToken { // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. return } + + // Handle Kiro IDE token file changes + if isKiroIDEToken { + w.handleKiroIDETokenChange(event) + return + } now := time.Now() log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name) @@ -805,6 +844,51 @@ func (w *Watcher) scheduleConfigReload() { }) } +// isKiroIDETokenFile checks if the given path is the Kiro IDE token file. +func (w *Watcher) isKiroIDETokenFile(path string) bool { + // Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/ + // Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes) + normalized := filepath.ToSlash(path) + return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache") +} + +// handleKiroIDETokenChange processes changes to the Kiro IDE token file. +// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth. +func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) { + log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name) + + if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 { + // Token file removed - wait briefly for potential atomic replace + time.Sleep(replaceCheckDelay) + if _, statErr := os.Stat(event.Name); statErr != nil { + log.Debugf("Kiro IDE token file removed: %s", event.Name) + return + } + } + + // Try to load the updated token + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + log.Debugf("failed to load Kiro IDE token after change: %v", err) + return + } + + log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider) + + // Trigger auth state refresh to pick up the new token + w.refreshAuthState() + + // Notify callback if set + w.clientsMutex.RLock() + cfg := w.config + w.clientsMutex.RUnlock() + + if w.reloadCallback != nil && cfg != nil { + log.Debugf("triggering server update callback after Kiro IDE token change") + w.reloadCallback(cfg) + } +} + func (w *Watcher) reloadConfigIfChanged() { data, err := os.ReadFile(w.configPath) if err != nil { @@ -1181,6 +1265,67 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey") out = append(out, a) } + // Kiro (AWS CodeWhisperer) -> synthesize auths + var kAuth *kiroauth.KiroAuth + if len(cfg.KiroKey) > 0 { + kAuth = kiroauth.NewKiroAuth(cfg) + } + for i := range cfg.KiroKey { + kk := cfg.KiroKey[i] + var accessToken, profileArn string + + // Try to load from token file first + if kk.TokenFile != "" && kAuth != nil { + tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile) + if err != nil { + log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err) + } else { + accessToken = tokenData.AccessToken + profileArn = tokenData.ProfileArn + } + } + + // Override with direct config values if provided + if kk.AccessToken != "" { + accessToken = kk.AccessToken + } + if kk.ProfileArn != "" { + profileArn = kk.ProfileArn + } + + if accessToken == "" { + log.Warnf("kiro config[%d] missing access_token, skipping", i) + continue + } + + // profileArn is optional for AWS Builder ID users + id, token := idGen.next("kiro:token", accessToken, profileArn) + attrs := map[string]string{ + "source": fmt.Sprintf("config:kiro[%s]", token), + "access_token": accessToken, + } + if profileArn != "" { + attrs["profile_arn"] = profileArn + } + if kk.Region != "" { + attrs["region"] = kk.Region + } + if kk.AgentTaskType != "" { + attrs["agent_task_type"] = kk.AgentTaskType + } + proxyURL := strings.TrimSpace(kk.ProxyURL) + a := &coreauth.Auth{ + ID: id, + Provider: "kiro", + Label: "kiro-token", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + out = append(out, a) + } for i := range cfg.OpenAICompatibility { compat := &cfg.OpenAICompatibility[i] providerName := strings.ToLower(strings.TrimSpace(compat.Name)) @@ -1287,7 +1432,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) - entries, _ := os.ReadDir(w.authDir) + log.Debugf("SnapshotCoreAuths: scanning auth directory: %s", w.authDir) + entries, readErr := os.ReadDir(w.authDir) + if readErr != nil { + log.Errorf("SnapshotCoreAuths: failed to read auth directory %s: %v", w.authDir, readErr) + } + log.Debugf("SnapshotCoreAuths: found %d entries in auth directory", len(entries)) for _, e := range entries { if e.IsDir() { continue @@ -1306,9 +1456,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { continue } t, _ := metadata["type"].(string) + + // Detect Kiro auth files by auth_method field (they don't have "type" field) if t == "" { + if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" { + t = "kiro" + log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name) + } + } + + if t == "" { + log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name) continue } + log.Debugf("SnapshotCoreAuths: processing auth file: %s (type=%s)", name, t) provider := strings.ToLower(t) if provider == "gemini" { provider = "gemini-cli" @@ -1317,6 +1478,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if email, _ := metadata["email"].(string); email != "" { label = email } + // For Kiro, use provider field as label if available + if provider == "kiro" { + if kiroProvider, _ := metadata["provider"].(string); kiroProvider != "" { + label = fmt.Sprintf("kiro-%s", strings.ToLower(kiroProvider)) + } + } // Use relative path under authDir as ID to stay consistent with the file-based token store id := full if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" { @@ -1342,6 +1509,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + // Set NextRefreshAfter for Kiro auth based on expires_at + if provider == "kiro" { + if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" { + if expiresAt, parseErr := time.Parse(time.RFC3339, expiresAtStr); parseErr == nil { + // Refresh 30 minutes before expiry + a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + } + } + } + applyAuthExcludedModelsMeta(a, cfg, nil, "oauth") if provider == "gemini-cli" { if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 76280b3a..3bbf6f61 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" @@ -50,7 +51,25 @@ type BaseAPIHandler struct { Cfg *config.SDKConfig // OpenAICompatProviders is a list of provider names for OpenAI compatibility. - OpenAICompatProviders []string + openAICompatProviders []string + openAICompatMutex sync.RWMutex +} + +// GetOpenAICompatProviders safely returns a copy of the provider names +func (h *BaseAPIHandler) GetOpenAICompatProviders() []string { + h.openAICompatMutex.RLock() + defer h.openAICompatMutex.RUnlock() + result := make([]string, len(h.openAICompatProviders)) + copy(result, h.openAICompatProviders) + return result +} + +// SetOpenAICompatProviders safely sets the provider names +func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) { + h.openAICompatMutex.Lock() + defer h.openAICompatMutex.Unlock() + h.openAICompatProviders = make([]string, len(providers)) + copy(h.openAICompatProviders, providers) } // NewBaseAPIHandlers creates a new API handlers instance. @@ -63,11 +82,12 @@ type BaseAPIHandler struct { // Returns: // - *BaseAPIHandler: A new API handlers instance func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler { - return &BaseAPIHandler{ - Cfg: cfg, - AuthManager: authManager, - OpenAICompatProviders: openAICompatProviders, + h := &BaseAPIHandler{ + Cfg: cfg, + AuthManager: authManager, } + h.SetOpenAICompatProviders(openAICompatProviders) + return h } // UpdateClients updates the handlers' client list and configuration. @@ -363,7 +383,7 @@ func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, mode } // Check if the provider is a configured openai-compatibility provider - for _, pName := range h.OpenAICompatProviders { + for _, pName := range h.GetOpenAICompatProviders() { if pName == providerPart { return providerPart, modelPart, true } diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go new file mode 100644 index 00000000..b95d103b --- /dev/null +++ b/sdk/auth/kiro.go @@ -0,0 +1,357 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// extractKiroIdentifier extracts a meaningful identifier for file naming. +// Returns account name if provided, otherwise profile ARN ID. +// All extracted values are sanitized to prevent path injection attacks. +func extractKiroIdentifier(accountName, profileArn string) string { + // Priority 1: Use account name if provided + if accountName != "" { + return kiroauth.SanitizeEmailForFilename(accountName) + } + + // Priority 2: Use profile ARN ID part (sanitized to prevent path injection) + if profileArn != "" { + parts := strings.Split(profileArn, "/") + if len(parts) >= 2 { + // Sanitize the ARN component to prevent path traversal + return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1]) + } + } + + // Fallback: timestamp + return fmt.Sprintf("%d", time.Now().UnixNano()%100000) +} + +// KiroAuthenticator implements OAuth authentication for Kiro with Google login. +type KiroAuthenticator struct{} + +// NewKiroAuthenticator constructs a Kiro authenticator. +func NewKiroAuthenticator() *KiroAuthenticator { + return &KiroAuthenticator{} +} + +// Provider returns the provider key for the authenticator. +func (a *KiroAuthenticator) Provider() string { + return "kiro" +} + +// RefreshLead indicates how soon before expiry a refresh should be attempted. +func (a *KiroAuthenticator) RefreshLead() *time.Duration { + d := 30 * time.Minute + return &d +} + +// Login performs OAuth login for Kiro with AWS Builder ID. +func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use AWS Builder ID device code flow + tokenData, err := oauth.LoginWithBuilderID(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-aws", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "aws-builder-id", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGoogle performs OAuth login for Kiro with Google. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use Google OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGoogle(ctx) + if err != nil { + return nil, fmt.Errorf("google login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-google-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-google", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "google-oauth", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro Google authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGitHub performs OAuth login for Kiro with GitHub. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use GitHub OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGitHub(ctx) + if err != nil { + return nil, fmt.Errorf("github login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-github-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-github", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "github-oauth", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro GitHub authentication completed successfully!") + } + + return record, nil +} + +// ImportFromKiroIDE imports token from Kiro IDE's token file. +func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) { + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract email from JWT if not already set (for imported tokens) + if tokenData.Email == "" { + tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + // Sanitize provider to prevent path traversal (defense-in-depth) + provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider))) + if provider == "" { + provider = "imported" // Fallback for legacy tokens without provider + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: fmt.Sprintf("kiro-%s", provider), + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "kiro-ide-import", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + // Display the email if extracted + if tokenData.Email != "" { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email) + } else { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider) + } + + return record, nil +} + +// Refresh refreshes an expired Kiro token using AWS SSO OIDC. +func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) { + if auth == nil || auth.Metadata == nil { + return nil, fmt.Errorf("invalid auth record") + } + + refreshToken, ok := auth.Metadata["refresh_token"].(string) + if !ok || refreshToken == "" { + return nil, fmt.Errorf("refresh token not found") + } + + clientID, _ := auth.Metadata["client_id"].(string) + clientSecret, _ := auth.Metadata["client_secret"].(string) + authMethod, _ := auth.Metadata["auth_method"].(string) + + var tokenData *kiroauth.KiroTokenData + var err error + + // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint + if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + } else { + // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) + oauth := kiroauth.NewKiroOAuth(cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("token refresh failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Clone auth to avoid mutating the input parameter + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization + updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + + return updated, nil +} diff --git a/sdk/auth/manager.go b/sdk/auth/manager.go index c6469a7d..d630f128 100644 --- a/sdk/auth/manager.go +++ b/sdk/auth/manager.go @@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config } return record, savedPath, nil } + +// SaveAuth persists an auth record directly without going through the login flow. +func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) { + if m.store == nil { + return "", fmt.Errorf("no store configured") + } + if cfg != nil { + if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok { + dirSetter.SetBaseDir(cfg.AuthDir) + } + } + return m.store.Save(context.Background(), record) +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index e82ac684..09406de5 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -14,6 +14,7 @@ func init() { registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) + registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 8b9a6639..7e1acb83 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -379,6 +379,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg)) + case "kiro": + s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -721,6 +723,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "iflow": models = registry.GetIFlowModels() models = applyExcludedModels(models, excluded) + case "kiro": + models = registry.GetKiroModels() + models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { From 9583f6b1c5124d1ac86f93361144eec5b81838ba Mon Sep 17 00:00:00 2001 From: Mansi Date: Fri, 5 Dec 2025 23:06:07 +0300 Subject: [PATCH 10/16] refactor: extract setKiroIncognitoMode helper to reduce code duplication - Added setKiroIncognitoMode() helper function to handle Kiro auth incognito mode setting - Replaced 3 duplicate code blocks (21 lines) with single function calls (3 lines) - Kiro auth defaults to incognito mode for multi-account support - Users can override with --incognito or --no-incognito flags This addresses the code duplication noted in PR #1 review. --- cmd/server/main.go | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 28c3e514..ef949187 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -47,6 +47,19 @@ func init() { buildinfo.BuildDate = BuildDate } +// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication. +// Kiro defaults to incognito mode for multi-account support. +// Users can explicitly override with --incognito or --no-incognito flags. +func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) { + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } +} + // main is the entry point of the application. // It parses command-line flags, loads configuration, and starts the appropriate // service based on the provided flags (login, codex-login, or server mode). @@ -465,36 +478,18 @@ func main() { // Users can explicitly override with --no-incognito // Note: This config mutation is safe - auth commands exit after completion // and don't share config with StartService (which is in the else branch) - if useIncognito { - cfg.IncognitoBrowser = true - } else if noIncognito { - cfg.IncognitoBrowser = false - } else { - cfg.IncognitoBrowser = true // Kiro default - } + setKiroIncognitoMode(cfg, useIncognito, noIncognito) cmd.DoKiroLogin(cfg, options) } else if kiroGoogleLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito // Note: This config mutation is safe - auth commands exit after completion - if useIncognito { - cfg.IncognitoBrowser = true - } else if noIncognito { - cfg.IncognitoBrowser = false - } else { - cfg.IncognitoBrowser = true // Kiro default - } + setKiroIncognitoMode(cfg, useIncognito, noIncognito) cmd.DoKiroGoogleLogin(cfg, options) } else if kiroAWSLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito - if useIncognito { - cfg.IncognitoBrowser = true - } else if noIncognito { - cfg.IncognitoBrowser = false - } else { - cfg.IncognitoBrowser = true // Kiro default - } + setKiroIncognitoMode(cfg, useIncognito, noIncognito) cmd.DoKiroAWSLogin(cfg, options) } else if kiroImport { cmd.DoKiroImport(cfg, options) From df83ba877f40e5528ae8348e935d54829217071b Mon Sep 17 00:00:00 2001 From: Mansi Date: Fri, 5 Dec 2025 23:16:48 +0300 Subject: [PATCH 11/16] refactor: replace custom stringContains with strings.Contains - Remove custom stringContains and findSubstring helper functions - Use standard library strings.Contains for better maintainability - No functional change, just cleaner code Addresses Gemini Code Assist review feedback --- internal/browser/browser.go | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/internal/browser/browser.go b/internal/browser/browser.go index 1ff0b469..3a5aeea7 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -6,6 +6,7 @@ import ( "fmt" "os/exec" "runtime" + "strings" "sync" pkgbrowser "github.com/pkg/browser" @@ -208,22 +209,7 @@ func tryDefaultBrowserMacOS(url string) *exec.Cmd { // containsBrowserID checks if the LaunchServices output contains a browser ID. func containsBrowserID(output, bundleID string) bool { - return stringContains(output, bundleID) -} - -// stringContains is a simple string contains check. -func stringContains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(substr) == 0 || - (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) -} - -func findSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false + return strings.Contains(output, bundleID) } // createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers. @@ -287,13 +273,13 @@ func tryDefaultBrowserWindows(url string) *exec.Cmd { var browserName string // Map ProgId to browser name - if stringContains(output, "ChromeHTML") { + if strings.Contains(output, "ChromeHTML") { browserName = "chrome" - } else if stringContains(output, "FirefoxURL") { + } else if strings.Contains(output, "FirefoxURL") { browserName = "firefox" - } else if stringContains(output, "MSEdgeHTM") { + } else if strings.Contains(output, "MSEdgeHTM") { browserName = "edge" - } else if stringContains(output, "BraveHTML") { + } else if strings.Contains(output, "BraveHTML") { browserName = "brave" } @@ -354,15 +340,15 @@ func tryDefaultBrowserLinux(url string) *exec.Cmd { var browserName string // Map .desktop file to browser name - if stringContains(desktop, "google-chrome") || stringContains(desktop, "chrome") { + if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") { browserName = "chrome" - } else if stringContains(desktop, "firefox") { + } else if strings.Contains(desktop, "firefox") { browserName = "firefox" - } else if stringContains(desktop, "chromium") { + } else if strings.Contains(desktop, "chromium") { browserName = "chromium" - } else if stringContains(desktop, "brave") { + } else if strings.Contains(desktop, "brave") { browserName = "brave" - } else if stringContains(desktop, "microsoft-edge") || stringContains(desktop, "msedge") { + } else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") { browserName = "edge" } From b06463c6d92450bedc696350fc832b870594b1b2 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 6 Dec 2025 12:04:36 +0800 Subject: [PATCH 12/16] docs: update README to include Kiro (AWS CodeWhisperer) support integration details --- README.md | 1 + README_CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 8e27ab05..b4037180 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ The Plus release stays in lockstep with the mainline features. ## Differences from the Mainline - Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) +- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration) ## Contributing diff --git a/README_CN.md b/README_CN.md index 84e90d40..f3260e9b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -11,6 +11,7 @@ ## 与主线版本版本差异 - 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 +- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供 ## 贡献 From b73e53d6c41eedb64d445c2fccba831c73b6533a Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 6 Dec 2025 12:14:23 +0800 Subject: [PATCH 13/16] docs: update README to fix formatting for GitHub Copilot and Kiro support sections --- README.md | 2 +- README_CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b4037180..7a552135 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ The Plus release stays in lockstep with the mainline features. ## Differences from the Mainline -- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) +- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) - Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration) ## Contributing diff --git a/README_CN.md b/README_CN.md index f3260e9b..163bec07 100644 --- a/README_CN.md +++ b/README_CN.md @@ -10,7 +10,7 @@ ## 与主线版本版本差异 -- 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 +- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 - 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供 ## 贡献 From 68cbe2066453e24dfd65ca249f83a0a66aac98c8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 7 Dec 2025 17:56:06 +0800 Subject: [PATCH 14/16] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0Kiro=E6=B8=A0?= =?UTF-8?q?=E9=81=93=E5=9B=BE=E7=89=87=E6=94=AF=E6=8C=81=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E5=80=9F=E9=89=B4justlovemaki/AIClient-2-API=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/runtime/executor/kiro_executor.go | 43 +++++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index eb068c14..0157d68c 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -604,10 +604,22 @@ type kiroHistoryMessage struct { AssistantResponseMessage *kiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` } +// kiroImage represents an image in Kiro API format +type kiroImage struct { + Format string `json:"format"` + Source kiroImageSource `json:"source"` +} + +// kiroImageSource contains the image data +type kiroImageSource struct { + Bytes string `json:"bytes"` // base64 encoded image data +} + type kiroUserInputMessage struct { Content string `json:"content"` ModelID string `json:"modelId"` Origin string `json:"origin"` + Images []kiroImage `json:"images,omitempty"` UserInputMessageContext *kiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` } @@ -824,6 +836,7 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin content := msg.Get("content") var contentBuilder strings.Builder var toolResults []kiroToolResult + var images []kiroImage if content.IsArray() { for _, part := range content.Array() { @@ -831,6 +844,25 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin switch partType { case "text": contentBuilder.WriteString(part.Get("text").String()) + case "image": + // Extract image data from Claude API format + mediaType := part.Get("source.media_type").String() + data := part.Get("source.data").String() + + // Extract format from media_type (e.g., "image/png" -> "png") + format := "" + if idx := strings.LastIndex(mediaType, "/"); idx != -1 { + format = mediaType[idx+1:] + } + + if format != "" && data != "" { + images = append(images, kiroImage{ + Format: format, + Source: kiroImageSource{ + Bytes: data, + }, + }) + } case "tool_result": // Extract tool result for API toolUseID := part.Get("tool_use_id").String() @@ -872,11 +904,18 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin contentBuilder.WriteString(content.String()) } - return kiroUserInputMessage{ + userMsg := kiroUserInputMessage{ Content: contentBuilder.String(), ModelID: modelID, Origin: origin, - }, toolResults + } + + // Add images to message if present + if len(images) > 0 { + userMsg.Images = images + } + + return userMsg, toolResults } // buildAssistantMessageStruct builds an assistant message with tool uses From a0c6cffb0da0d53d94a6c6bfb8cdf528773a09ee Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sun, 7 Dec 2025 21:55:13 +0800 Subject: [PATCH 15/16] =?UTF-8?q?fix(kiro):=E4=BF=AE=E5=A4=8D=20base64=20?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E6=A0=BC=E5=BC=8F=E8=BD=AC=E6=8D=A2=E9=97=AE?= =?UTF-8?q?=E9=A2=98=20(#10)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat-completions/kiro_openai_request.go | 89 ++++++++++++++++--- 1 file changed, 75 insertions(+), 14 deletions(-) diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go index fc850d96..3d339505 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -4,6 +4,7 @@ package chat_completions import ( "bytes" "encoding/json" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -182,13 +183,43 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo "text": part.Get("text").String(), }) } else if partType == "image_url" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": part.Get("image_url.url").String(), - }, - }) + imageURL := part.Get("image_url.url").String() + + // Check if it's base64 format (data:image/png;base64,xxxxx) + if strings.HasPrefix(imageURL, "data:") { + // Parse data URL format + // Format: data:image/png;base64,xxxxx + commaIdx := strings.Index(imageURL, ",") + if commaIdx != -1 { + // Extract media_type (e.g., "image/png") + header := imageURL[5:commaIdx] // Remove "data:" prefix + mediaType := header + if semiIdx := strings.Index(header, ";"); semiIdx != -1 { + mediaType = header[:semiIdx] + } + + // Extract base64 data + base64Data := imageURL[commaIdx+1:] + + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": base64Data, + }, + }) + } + } else { + // Regular URL format - keep original logic + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + }) + } } } } else if content.String() != "" { @@ -214,13 +245,43 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo "text": part.Get("text").String(), }) } else if partType == "image_url" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": part.Get("image_url.url").String(), - }, - }) + imageURL := part.Get("image_url.url").String() + + // Check if it's base64 format (data:image/png;base64,xxxxx) + if strings.HasPrefix(imageURL, "data:") { + // Parse data URL format + // Format: data:image/png;base64,xxxxx + commaIdx := strings.Index(imageURL, ",") + if commaIdx != -1 { + // Extract media_type (e.g., "image/png") + header := imageURL[5:commaIdx] // Remove "data:" prefix + mediaType := header + if semiIdx := strings.Index(header, ";"); semiIdx != -1 { + mediaType = header[:semiIdx] + } + + // Extract base64 data + base64Data := imageURL[commaIdx+1:] + + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": base64Data, + }, + }) + } + } else { + // Regular URL format - keep original logic + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + }) + } } } claudeMsg["content"] = contentParts From ab9e9442ec0a63a29ebf6f9468d330fd76280bb9 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 8 Dec 2025 22:32:29 +0800 Subject: [PATCH 16/16] v6.5.56 (#12) * feat(api): add comprehensive ampcode management endpoints Add new REST API endpoints under /v0/management/ampcode for managing ampcode configuration including upstream URL, API key, localhost restriction, model mappings, and force model mappings settings. - Move force-model-mappings from config_basic to config_lists - Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings - Support model mapping CRUD with upsert (PATCH) capability - Add comprehensive test coverage for all ampcode endpoints * refactor(api): simplify request body parsing in ampcode handlers * feat(logging): add upstream API request/response capture to streaming logs * style(logging): remove redundant separator line from response section * feat(antigravity): enforce thinking budget limits for Claude models * refactor(logging): remove unused variable in `ensureAttempt` and redundant function call --------- Co-authored-by: hkfires <10558748+hkfires@users.noreply.github.com> --- .../api/handlers/management/config_basic.go | 8 - .../api/handlers/management/config_lists.go | 152 ++++ internal/api/handlers/management/handler.go | 10 - internal/api/middleware/response_writer.go | 9 + internal/api/server.go | 22 +- internal/logging/request_logger.go | 202 ++++- internal/registry/model_definitions.go | 31 +- .../runtime/executor/antigravity_executor.go | 63 +- internal/runtime/executor/logging_helpers.go | 4 +- .../antigravity_openai_request.go | 5 +- test/amp_management_test.go | 827 ++++++++++++++++++ 11 files changed, 1263 insertions(+), 70 deletions(-) create mode 100644 test/amp_management_test.go diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index c788aca4..f9069198 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -241,11 +241,3 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } - -// Force Model Mappings (for Amp CLI) -func (h *Handler) GetForceModelMappings(c *gin.Context) { - c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) -} -func (h *Handler) PutForceModelMappings(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) -} diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 8f4c4037..a0d0b169 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { } entry.Models = normalized } + +// GetAmpCode returns the complete ampcode configuration. +func (h *Handler) GetAmpCode(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"ampcode": config.AmpCode{}}) + return + } + c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode}) +} + +// GetAmpUpstreamURL returns the ampcode upstream URL. +func (h *Handler) GetAmpUpstreamURL(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-url": ""}) + return + } + c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL}) +} + +// PutAmpUpstreamURL updates the ampcode upstream URL. +func (h *Handler) PutAmpUpstreamURL(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamURL clears the ampcode upstream URL. +func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) { + h.cfg.AmpCode.UpstreamURL = "" + h.persist(c) +} + +// GetAmpUpstreamAPIKey returns the ampcode upstream API key. +func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-api-key": ""}) + return + } + c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey}) +} + +// PutAmpUpstreamAPIKey updates the ampcode upstream API key. +func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key. +func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) { + h.cfg.AmpCode.UpstreamAPIKey = "" + h.persist(c) +} + +// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting. +func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"restrict-management-to-localhost": true}) + return + } + c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost}) +} + +// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting. +func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v }) +} + +// GetAmpModelMappings returns the ampcode model mappings. +func (h *Handler) GetAmpModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}}) + return + } + c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings}) +} + +// PutAmpModelMappings replaces all ampcode model mappings. +func (h *Handler) PutAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + h.cfg.AmpCode.ModelMappings = body.Value + h.persist(c) +} + +// PatchAmpModelMappings adds or updates model mappings. +func (h *Handler) PatchAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + + existing := make(map[string]int) + for i, m := range h.cfg.AmpCode.ModelMappings { + existing[strings.TrimSpace(m.From)] = i + } + + for _, newMapping := range body.Value { + from := strings.TrimSpace(newMapping.From) + if idx, ok := existing[from]; ok { + h.cfg.AmpCode.ModelMappings[idx] = newMapping + } else { + h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping) + existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1 + } + } + h.persist(c) +} + +// DeleteAmpModelMappings removes specified model mappings by "from" field. +func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 { + h.cfg.AmpCode.ModelMappings = nil + h.persist(c) + return + } + + toRemove := make(map[string]bool) + for _, from := range body.Value { + toRemove[strings.TrimSpace(from)] = true + } + + newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings)) + for _, m := range h.cfg.AmpCode.ModelMappings { + if !toRemove[strings.TrimSpace(m.From)] { + newMappings = append(newMappings, m) + } + } + h.cfg.AmpCode.ModelMappings = newMappings + h.persist(c) +} + +// GetAmpForceModelMappings returns whether model mappings are forced. +func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"force-model-mappings": false}) + return + } + c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) +} + +// PutAmpForceModelMappings updates the force model mappings setting. +func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) +} diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index ef6f400a..39e6b7fd 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { Value *bool `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { - var m map[string]any - if err2 := c.ShouldBindJSON(&m); err2 == nil { - for _, v := range m { - if b, ok := v.(bool); ok { - set(b) - h.persist(c) - return - } - } - } c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index f0d1ad26..b7259bc6 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { w.streamDone = nil } + // Write API Request and Response to the streaming log before closing if w.streamWriter != nil { + apiRequest := w.extractAPIRequest(c) + if len(apiRequest) > 0 { + _ = w.streamWriter.WriteAPIRequest(apiRequest) + } + apiResponse := w.extractAPIResponse(c) + if len(apiResponse) > 0 { + _ = w.streamWriter.WriteAPIResponse(apiResponse) + } if err := w.streamWriter.Close(); err != nil { w.streamWriter = nil return err diff --git a/internal/api/server.go b/internal/api/server.go index 1f35429e..2e463de4 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,9 +520,25 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) - mgmt.GET("/force-model-mappings", s.mgmt.GetForceModelMappings) - mgmt.PUT("/force-model-mappings", s.mgmt.PutForceModelMappings) - mgmt.PATCH("/force-model-mappings", s.mgmt.PutForceModelMappings) + mgmt.GET("/ampcode", s.mgmt.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index c574febb..f8c068c5 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -84,6 +84,26 @@ type StreamingLogWriter interface { // - error: An error if writing fails, nil otherwise WriteStatus(status int, headers map[string][]string) error + // WriteAPIRequest writes the upstream API request details to the log. + // This should be called before WriteStatus to maintain proper log ordering. + // + // Parameters: + // - apiRequest: The API request data (typically includes URL, headers, body sent upstream) + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIRequest(apiRequest []byte) error + + // WriteAPIResponse writes the upstream API response details to the log. + // This should be called after the streaming response is complete. + // + // Parameters: + // - apiResponse: The API response data + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIResponse(apiResponse []byte) error + // Close finalizes the log file and cleans up resources. // // Returns: @@ -248,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ // Create streaming writer writer := &FileStreamingLogWriter{ - file: file, - chunkChan: make(chan []byte, 100), // Buffered channel for async writes - closeChan: make(chan struct{}), - errorChan: make(chan error, 1), + file: file, + chunkChan: make(chan []byte, 100), // Buffered channel for async writes + closeChan: make(chan struct{}), + errorChan: make(chan error, 1), + bufferedChunks: &bytes.Buffer{}, } // Start async writer goroutine @@ -628,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. // It handles asynchronous writing of streaming response chunks to a file. +// All data is buffered and written in the correct order when Close is called. type FileStreamingLogWriter struct { // file is the file where log data is written. file *os.File - // chunkChan is a channel for receiving response chunks to write. + // chunkChan is a channel for receiving response chunks to buffer. chunkChan chan []byte // closeChan is a channel for signaling when the writer is closed. @@ -641,8 +663,23 @@ type FileStreamingLogWriter struct { // errorChan is a channel for reporting errors during writing. errorChan chan error - // statusWritten indicates whether the response status has been written. + // bufferedChunks stores the response chunks in order. + bufferedChunks *bytes.Buffer + + // responseStatus stores the HTTP status code. + responseStatus int + + // statusWritten indicates whether a non-zero status was recorded. statusWritten bool + + // responseHeaders stores the response headers. + responseHeaders map[string][]string + + // apiRequest stores the upstream API request data. + apiRequest []byte + + // apiResponse stores the upstream API response data. + apiResponse []byte } // WriteChunkAsync writes a response chunk asynchronously (non-blocking). @@ -666,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) { } } -// WriteStatus writes the response status and headers to the log. +// WriteStatus buffers the response status and headers for later writing. // // Parameters: // - status: The response status code // - headers: The response headers // // Returns: -// - error: An error if writing fails, nil otherwise +// - error: Always returns nil (buffering cannot fail) func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { - if w.file == nil || w.statusWritten { + if status == 0 { return nil } - var content strings.Builder - content.WriteString("========================================\n") - content.WriteString("=== RESPONSE ===\n") - content.WriteString(fmt.Sprintf("Status: %d\n", status)) - - for key, values := range headers { - for _, value := range values { - content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + w.responseStatus = status + if headers != nil { + w.responseHeaders = make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + w.responseHeaders[key] = headerValues } } - content.WriteString("\n") + w.statusWritten = true + return nil +} - _, err := w.file.WriteString(content.String()) - if err == nil { - w.statusWritten = true +// WriteAPIRequest buffers the upstream API request details for later writing. +// +// Parameters: +// - apiRequest: The API request data (typically includes URL, headers, body sent upstream) +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error { + if len(apiRequest) == 0 { + return nil } - return err + w.apiRequest = bytes.Clone(apiRequest) + return nil +} + +// WriteAPIResponse buffers the upstream API response details for later writing. +// +// Parameters: +// - apiResponse: The API response data +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { + if len(apiResponse) == 0 { + return nil + } + w.apiResponse = bytes.Clone(apiResponse) + return nil } // Close finalizes the log file and cleans up resources. +// It writes all buffered data to the file in the correct order: +// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // // Returns: // - error: An error if closing fails, nil otherwise @@ -707,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error { close(w.chunkChan) } - // Wait for async writer to finish + // Wait for async writer to finish buffering chunks if w.closeChan != nil { <-w.closeChan w.chunkChan = nil } - if w.file != nil { - return w.file.Close() + if w.file == nil { + return nil } - return nil + // Write all content in the correct order + var content strings.Builder + + // 1. Write API REQUEST section + if len(w.apiRequest) > 0 { + if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) { + content.Write(w.apiRequest) + if !bytes.HasSuffix(w.apiRequest, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API REQUEST ===\n") + content.Write(w.apiRequest) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 2. Write API RESPONSE section + if len(w.apiResponse) > 0 { + if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { + content.Write(w.apiResponse) + if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API RESPONSE ===\n") + content.Write(w.apiResponse) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 3. Write RESPONSE section (status, headers, buffered chunks) + content.WriteString("=== RESPONSE ===\n") + if w.statusWritten { + content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) + } + + for key, values := range w.responseHeaders { + for _, value := range values { + content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + } + } + content.WriteString("\n") + + // Write buffered response body chunks + if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { + content.Write(w.bufferedChunks.Bytes()) + } + + // Write the complete content to file + if _, err := w.file.WriteString(content.String()); err != nil { + _ = w.file.Close() + return err + } + + return w.file.Close() } -// asyncWriter runs in a goroutine to handle async chunk writing. -// It continuously reads chunks from the channel and writes them to the file. +// asyncWriter runs in a goroutine to buffer chunks from the channel. +// It continuously reads chunks from the channel and buffers them for later writing. func (w *FileStreamingLogWriter) asyncWriter() { defer close(w.closeChan) for chunk := range w.chunkChan { - if w.file != nil { - _, _ = w.file.Write(chunk) + if w.bufferedChunks != nil { + w.bufferedChunks.Write(chunk) } } } @@ -754,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error return nil } +// WriteAPIRequest is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiRequest: The API request data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error { + return nil +} + +// WriteAPIResponse is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiResponse: The API response data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { + return nil +} + // Close is a no-op implementation that does nothing and always returns nil. // // Returns: diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index b25d91c2..fc7e75a1 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -943,18 +943,6 @@ func GetQwenModels() []*ModelInfo { } } -// GetAntigravityThinkingConfig returns the Thinking configuration for antigravity models. -// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. -func GetAntigravityThinkingConfig() map[string]*ThinkingSupport { - return map[string]*ThinkingSupport{ - "gemini-2.5-flash": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-2.5-flash-lite": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-3-pro-preview": {Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-sonnet-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-opus-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - } -} - // GetIFlowModels returns supported models for iFlow OAuth accounts. func GetIFlowModels() []*ModelInfo { entries := []struct { @@ -998,6 +986,25 @@ func GetIFlowModels() []*ModelInfo { return models } +// AntigravityModelConfig captures static antigravity model overrides, including +// Thinking budget limits and provider max completion tokens. +type AntigravityModelConfig struct { + Thinking *ThinkingSupport + MaxCompletionTokens int +} + +// GetAntigravityModelConfig returns static configuration for antigravity models. +// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. +func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { + return map[string]*AntigravityModelConfig{ + "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}}, + "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + } +} + // GetGitHubCopilotModels returns the available models for GitHub Copilot. // These models are available through the GitHub Copilot API at api.githubcopilot.com. func GetGitHubCopilotModels() []*ModelInfo { diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index ce836a77..d83559ab 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -81,6 +81,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -174,6 +175,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -370,7 +372,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } now := time.Now().Unix() - thinkingConfig := registry.GetAntigravityThinkingConfig() + modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) for originalName := range result.Map() { aliasName := modelName2Alias(originalName) @@ -387,8 +389,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c Type: antigravityAuthType, } // Look up Thinking support from static config using alias name - if thinking, ok := thinkingConfig[aliasName]; ok { - modelInfo.Thinking = thinking + if cfg, ok := modelConfig[aliasName]; ok { + if cfg.Thinking != nil { + modelInfo.Thinking = cfg.Thinking + } + if cfg.MaxCompletionTokens > 0 { + modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens + } } models = append(models, modelInfo) } @@ -812,3 +819,53 @@ func alias2ModelName(modelName string) string { return modelName } } + +// normalizeAntigravityThinking clamps or removes thinking config based on model support. +// For Claude models, it additionally ensures thinking budget < max_tokens. +func normalizeAntigravityThinking(model string, payload []byte) []byte { + payload = util.StripThinkingConfigIfUnsupported(model, payload) + if !util.ModelSupportsThinking(model) { + return payload + } + budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget") + if !budget.Exists() { + return payload + } + raw := int(budget.Int()) + normalized := util.NormalizeThinkingBudget(model, raw) + + isClaude := strings.Contains(strings.ToLower(model), "claude") + if isClaude { + effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) + if effectiveMax > 0 && normalized >= effectiveMax { + normalized = effectiveMax - 1 + if normalized < 1 { + normalized = 1 + } + } + if setDefaultMax { + if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil { + payload = res + } + } + } + + updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + if err != nil { + return payload + } + return updated +} + +// antigravityEffectiveMaxTokens returns the max tokens to cap thinking: +// prefer request-provided maxOutputTokens; otherwise fall back to model default. +// The boolean indicates whether the value came from the model default (and thus should be written back). +func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) { + if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 { + return int(maxTok.Int()), false + } + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 { + return modelInfo.MaxCompletionTokens, true + } + return 0, false +} diff --git a/internal/runtime/executor/logging_helpers.go b/internal/runtime/executor/logging_helpers.go index 26931f53..7798b96b 100644 --- a/internal/runtime/executor/logging_helpers.go +++ b/internal/runtime/executor/logging_helpers.go @@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt if ginCtx == nil { return } - attempts, attempt := ensureAttempt(ginCtx) + _, attempt := ensureAttempt(ginCtx) ensureResponseIntro(attempt) if !attempt.headersWritten { @@ -175,8 +175,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt } attempt.response.WriteString(string(data)) attempt.bodyHasContent = true - - updateAggregatedResponse(ginCtx, attempts) } func ginContextFrom(ctx context.Context) *gin.Context { diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 82e71758..1c90a803 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -111,7 +111,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } - // Temperature/top_p/top_k + // Temperature/top_p/top_k/max_tokens if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) } @@ -121,6 +121,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) } + if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num) + } // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"] diff --git a/test/amp_management_test.go b/test/amp_management_test.go new file mode 100644 index 00000000..19450dbf --- /dev/null +++ b/test/amp_management_test.go @@ -0,0 +1,827 @@ +package test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// newAmpTestHandler creates a test handler with default ampcode configuration. +func newAmpTestHandler(t *testing.T) (*management.Handler, string) { + t.Helper() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamURL: "https://example.com", + UpstreamAPIKey: "test-api-key-12345", + RestrictManagementToLocalhost: true, + ForceModelMappings: false, + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "gemini-pro"}, + }, + }, + } + + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + return h, configPath +} + +// setupAmpRouter creates a test router with all ampcode management endpoints. +func setupAmpRouter(h *management.Handler) *gin.Engine { + r := gin.New() + mgmt := r.Group("/v0/management") + { + mgmt.GET("/ampcode", h.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings) + } + return r +} + +// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config. +func TestGetAmpCode(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]config.AmpCode + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + ampcode := resp["ampcode"] + if ampcode.UpstreamURL != "https://example.com" { + t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL) + } + if len(ampcode.ModelMappings) != 1 { + t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings)) + } +} + +// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL. +func TestGetAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["upstream-url"] != "https://example.com" { + t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL. +func TestPutAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-upstream.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL. +func TestDeleteAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key. +func TestGetAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + key := resp["upstream-api-key"].(string) + if key != "test-api-key-12345" { + t.Errorf("expected key %q, got %q", "test-api-key-12345", key) + } +} + +// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key. +func TestPutAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-key"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. +func TestDeleteAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting. +func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["restrict-management-to-localhost"] != true { + t.Error("expected restrict-management-to-localhost to be true") + } +} + +// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting. +func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings. +func TestGetAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping, got %d", len(mappings)) + } + if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" { + t.Errorf("unexpected mapping: %+v", mappings[0]) + } +} + +// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings. +func TestPutAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones. +func TestPatchAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field. +func TestDeleteAmpModelMappings_Specific(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": ["gpt-4"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. +func TestDeleteAmpModelMappings_All(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting. +func TestGetAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["force-model-mappings"] != false { + t.Error("expected force-model-mappings to be false") + } +} + +// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting. +func TestPutAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted. +func TestPutAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String()) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings, got %d", len(mappings)) + } + + expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"} + for _, m := range mappings { + if expected[m.From] != m.To { + t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To) + } + } +} + +// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly. +func TestPatchAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PATCH failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 2 { + t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings)) + } + + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + if found["gpt-4"] != "updated-target" { + t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"]) + } + if found["new-model"] != "new-target" { + t.Errorf("new-model should map to new-target, got %q", found["new-model"]) + } +} + +// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others. +func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["a", "c"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping remaining, got %d", len(mappings)) + } + if mappings[0].From != "b" || mappings[0].To != "2" { + t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To) + } +} + +// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones. +func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + delBody := `{"value": ["non-existent-model"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 1 { + t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"])) + } +} + +// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings. +func TestPutAmpModelMappings_Empty(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": []}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} + +// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state. +func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-api.example.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "https://new-api.example.com" { + t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"]) + } +} + +// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL. +func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state. +func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-api-key-xyz"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "new-secret-api-key-xyz" { + t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"]) + } +} + +// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key. +func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-api-key"]) + } +} + +// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction. +func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["restrict-management-to-localhost"] != false { + t.Error("expected false after update") + } +} + +// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting. +func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["force-model-mappings"] != true { + t.Error("expected true after update") + } +} + +// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400. +func TestPutBoolField_EmptyObject(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code) + } +} + +// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET. +func TestComplexMappingsWorkflow(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}` + req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["m1", "m3"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings)) + } + + expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"} + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + for from, to := range expected { + if found[from] != to { + t.Errorf("mapping %s: expected %q, got %q", from, to, found[from]) + } + } +} + +// TestNilHandlerGetAmpCode verifies handler works with empty config. +func TestNilHandlerGetAmpCode(t *testing.T) { + cfg := &config.Config{} + h := management.NewHandler(cfg, "", nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config. +func TestEmptyConfigGetAmpModelMappings(t *testing.T) { + cfg := &config.Config{} + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +}