From 3a9ac7ef331da59da78d8504cab00a639f837cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Thu, 27 Nov 2025 20:14:30 +0100 Subject: [PATCH 01/38] feat(auth): add GitHub Copilot authentication and API integration Add complete GitHub Copilot support including: - Device flow OAuth authentication via GitHub's official client ID - Token management with automatic caching (25 min TTL) - OpenAI-compatible API executor for api.githubcopilot.com - 16 model definitions (GPT-5 variants, Claude variants, Gemini, Grok, Raptor) - CLI login command via -github-copilot-login flag - SDK authenticator and refresh registry integration Enables users to authenticate with their GitHub Copilot subscription and use it as a backend provider alongside existing providers. --- cmd/server/main.go | 5 + internal/auth/copilot/copilot_auth.go | 301 +++++++++++++++ internal/auth/copilot/errors.go | 183 +++++++++ internal/auth/copilot/oauth.go | 259 +++++++++++++ internal/auth/copilot/token.go | 93 +++++ internal/cmd/auth_manager.go | 3 +- internal/cmd/github_copilot_login.go | 44 +++ internal/registry/model_definitions.go | 184 +++++++++ .../executor/github_copilot_executor.go | 354 ++++++++++++++++++ sdk/auth/github_copilot.go | 133 +++++++ sdk/auth/refresh_registry.go | 1 + sdk/cliproxy/service.go | 4 + 12 files changed, 1563 insertions(+), 1 deletion(-) create mode 100644 internal/auth/copilot/copilot_auth.go create mode 100644 internal/auth/copilot/errors.go create mode 100644 internal/auth/copilot/oauth.go create mode 100644 internal/auth/copilot/token.go create mode 100644 internal/cmd/github_copilot_login.go create mode 100644 internal/runtime/executor/github_copilot_executor.go create mode 100644 sdk/auth/github_copilot.go diff --git a/cmd/server/main.go b/cmd/server/main.go index bbf500e7..b8ee66ba 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -62,6 +62,7 @@ func main() { var iflowCookie bool var noBrowser bool var antigravityLogin bool + var 
githubCopilotLogin bool var projectID string var vertexImport string var configPath string @@ -76,6 +77,7 @@ func main() { flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") + flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -436,6 +438,9 @@ func main() { } else if antigravityLogin { // Handle Antigravity login cmd.DoAntigravityLogin(cfg, options) + } else if githubCopilotLogin { + // Handle GitHub Copilot login + cmd.DoGitHubCopilotLogin(cfg, options) } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go new file mode 100644 index 00000000..fbfb1762 --- /dev/null +++ b/internal/auth/copilot/copilot_auth.go @@ -0,0 +1,301 @@ +// Package copilot provides authentication and token management for GitHub Copilot API. +// It handles the OAuth2 device flow for secure authentication with the Copilot API. +package copilot + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token. + copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token" + // copilotAPIEndpoint is the base URL for making API requests. 
+ copilotAPIEndpoint = "https://api.githubcopilot.com" +) + +// CopilotAPIToken represents the Copilot API token response. +type CopilotAPIToken struct { + // Token is the JWT token for authenticating with the Copilot API. + Token string `json:"token"` + // ExpiresAt is the Unix timestamp when the token expires. + ExpiresAt int64 `json:"expires_at"` + // Endpoints contains the available API endpoints. + Endpoints struct { + API string `json:"api"` + Proxy string `json:"proxy"` + OriginTracker string `json:"origin-tracker"` + Telemetry string `json:"telemetry"` + } `json:"endpoints,omitempty"` + // ErrorDetails contains error information if the request failed. + ErrorDetails *struct { + URL string `json:"url"` + Message string `json:"message"` + DocumentationURL string `json:"documentation_url"` + } `json:"error_details,omitempty"` +} + +// CopilotAuth handles GitHub Copilot authentication flow. +// It provides methods for device flow authentication and token management. +type CopilotAuth struct { + httpClient *http.Client + deviceClient *DeviceFlowClient + cfg *config.Config +} + +// NewCopilotAuth creates a new CopilotAuth service instance. +// It initializes an HTTP client with proxy settings from the provided configuration. +func NewCopilotAuth(cfg *config.Config) *CopilotAuth { + return &CopilotAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + deviceClient: NewDeviceFlowClient(cfg), + cfg: cfg, + } +} + +// StartDeviceFlow initiates the device flow authentication. +// Returns the device code response containing the user code and verification URI. +func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) { + return c.deviceClient.RequestDeviceCode(ctx) +} + +// WaitForAuthorization polls for user authorization and returns the auth bundle. 
+func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) { + tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode) + if err != nil { + return nil, err + } + + // Fetch the GitHub username + username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken) + if err != nil { + log.Warnf("copilot: failed to fetch user info: %v", err) + username = "unknown" + } + + return &CopilotAuthBundle{ + TokenData: tokenData, + Username: username, + }, nil +} + +// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token. +// This token is used to make authenticated requests to the Copilot API. +func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) { + if githubAccessToken == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty")) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + req.Header.Set("Authorization", "token "+githubAccessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "GithubCopilot/1.0") + req.Header.Set("Editor-Version", "vscode/1.100.0") + req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot api token: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, + fmt.Errorf("status %d: %s", 
resp.StatusCode, string(bodyBytes))) + } + + var apiToken CopilotAPIToken + if err = json.Unmarshal(bodyBytes, &apiToken); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if apiToken.Token == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token")) + } + + return &apiToken, nil +} + +// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info. +func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) { + if accessToken == "" { + return false, "", nil + } + + username, err := c.deviceClient.FetchUserInfo(ctx, accessToken) + if err != nil { + return false, "", err + } + + return true, username, nil +} + +// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle. +func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage { + return &CopilotTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + TokenType: bundle.TokenData.TokenType, + Scope: bundle.TokenData.Scope, + Username: bundle.Username, + Type: "github-copilot", + } +} + +// LoadAndValidateToken loads a token from storage and validates it. +// Returns the storage if valid, or an error if the token is invalid or expired. +func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) { + if storage == nil || storage.AccessToken == "" { + return false, fmt.Errorf("no token available") + } + + // Check if we can still use the GitHub token to get a Copilot API token + apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken) + if err != nil { + return false, err + } + + // Check if the API token is expired + if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt { + return false, fmt.Errorf("copilot api token expired") + } + + return true, nil +} + +// GetAPIEndpoint returns the Copilot API endpoint URL. 
+func (c *CopilotAuth) GetAPIEndpoint() string { + return copilotAPIEndpoint +} + +// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API. +func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+apiToken.Token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "GithubCopilot/1.0") + req.Header.Set("Editor-Version", "vscode/1.100.0") + req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + req.Header.Set("Openai-Intent", "conversation-panel") + req.Header.Set("Copilot-Integration-Id", "vscode-chat") + + return req, nil +} + +// CachedAPIToken manages caching of Copilot API tokens. +type CachedAPIToken struct { + Token *CopilotAPIToken + ExpiresAt time.Time +} + +// IsExpired checks if the cached token has expired. +func (c *CachedAPIToken) IsExpired() bool { + if c.Token == nil { + return true + } + // Add a 5-minute buffer before expiration + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +// TokenManager handles caching and refreshing of Copilot API tokens. +type TokenManager struct { + auth *CopilotAuth + githubToken string + cachedToken *CachedAPIToken +} + +// NewTokenManager creates a new token manager for handling Copilot API tokens. +func NewTokenManager(auth *CopilotAuth, githubToken string) *TokenManager { + return &TokenManager{ + auth: auth, + githubToken: githubToken, + } +} + +// GetToken returns a valid Copilot API token, refreshing if necessary. 
+func (tm *TokenManager) GetToken(ctx context.Context) (*CopilotAPIToken, error) { + if tm.cachedToken != nil && !tm.cachedToken.IsExpired() { + return tm.cachedToken.Token, nil + } + + // Fetch a new API token + apiToken, err := tm.auth.GetCopilotAPIToken(ctx, tm.githubToken) + if err != nil { + return nil, err + } + + // Cache the token + expiresAt := time.Now().Add(30 * time.Minute) // Default 30 min cache + if apiToken.ExpiresAt > 0 { + expiresAt = time.Unix(apiToken.ExpiresAt, 0) + } + + tm.cachedToken = &CachedAPIToken{ + Token: apiToken, + ExpiresAt: expiresAt, + } + + return apiToken, nil +} + +// GetAuthorizationHeader returns the authorization header value for API requests. +func (tm *TokenManager) GetAuthorizationHeader(ctx context.Context) (string, error) { + token, err := tm.GetToken(ctx) + if err != nil { + return "", err + } + return "Bearer " + token.Token, nil +} + +// UpdateGitHubToken updates the GitHub access token used for getting API tokens. +func (tm *TokenManager) UpdateGitHubToken(githubToken string) { + tm.githubToken = githubToken + tm.cachedToken = nil // Invalidate cache +} + +// BuildChatCompletionURL builds the URL for chat completions API. +func BuildChatCompletionURL() string { + return copilotAPIEndpoint + "/chat/completions" +} + +// BuildModelsURL builds the URL for listing available models. +func BuildModelsURL() string { + return copilotAPIEndpoint + "/models" +} + +// ExtractBearerToken extracts the bearer token from an Authorization header. 
+func ExtractBearerToken(authHeader string) string { + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + if strings.HasPrefix(authHeader, "token ") { + return strings.TrimPrefix(authHeader, "token ") + } + return authHeader +} diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go new file mode 100644 index 00000000..01f8e754 --- /dev/null +++ b/internal/auth/copilot/errors.go @@ -0,0 +1,183 @@ +package copilot + +import ( + "errors" + "fmt" + "net/http" +) + +// OAuthError represents an OAuth-specific error. +type OAuthError struct { + // Code is the OAuth error code. + Code string `json:"error"` + // Description is a human-readable description of the error. + Description string `json:"error_description,omitempty"` + // URI is a URI identifying a human-readable web page with information about the error. + URI string `json:"error_uri,omitempty"` + // StatusCode is the HTTP status code associated with the error. + StatusCode int `json:"-"` +} + +// Error returns a string representation of the OAuth error. +func (e *OAuthError) Error() string { + if e.Description != "" { + return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description) + } + return fmt.Sprintf("OAuth error: %s", e.Code) +} + +// NewOAuthError creates a new OAuth error with the specified code, description, and status code. +func NewOAuthError(code, description string, statusCode int) *OAuthError { + return &OAuthError{ + Code: code, + Description: description, + StatusCode: statusCode, + } +} + +// AuthenticationError represents authentication-related errors. +type AuthenticationError struct { + // Type is the type of authentication error. + Type string `json:"type"` + // Message is a human-readable message describing the error. + Message string `json:"message"` + // Code is the HTTP status code associated with the error. + Code int `json:"code"` + // Cause is the underlying error that caused this authentication error. 
+ Cause error `json:"-"` +} + +// Error returns a string representation of the authentication error. +func (e *AuthenticationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("%s: %s", e.Type, e.Message) +} + +// Common authentication error types for GitHub Copilot device flow. +var ( + // ErrDeviceCodeFailed represents an error when requesting the device code fails. + ErrDeviceCodeFailed = &AuthenticationError{ + Type: "device_code_failed", + Message: "Failed to request device code from GitHub", + Code: http.StatusBadRequest, + } + + // ErrDeviceCodeExpired represents an error when the device code has expired. + ErrDeviceCodeExpired = &AuthenticationError{ + Type: "device_code_expired", + Message: "Device code has expired. Please try again.", + Code: http.StatusGone, + } + + // ErrAuthorizationPending represents a pending authorization state (not an error, used for polling). + ErrAuthorizationPending = &AuthenticationError{ + Type: "authorization_pending", + Message: "Authorization is pending. Waiting for user to authorize.", + Code: http.StatusAccepted, + } + + // ErrSlowDown represents a request to slow down polling. + ErrSlowDown = &AuthenticationError{ + Type: "slow_down", + Message: "Polling too frequently. Slowing down.", + Code: http.StatusTooManyRequests, + } + + // ErrAccessDenied represents an error when the user denies authorization. + ErrAccessDenied = &AuthenticationError{ + Type: "access_denied", + Message: "User denied authorization", + Code: http.StatusForbidden, + } + + // ErrTokenExchangeFailed represents an error when token exchange fails. + ErrTokenExchangeFailed = &AuthenticationError{ + Type: "token_exchange_failed", + Message: "Failed to exchange device code for access token", + Code: http.StatusBadRequest, + } + + // ErrPollingTimeout represents an error when polling times out. 
+ ErrPollingTimeout = &AuthenticationError{ + Type: "polling_timeout", + Message: "Timeout waiting for user authorization", + Code: http.StatusRequestTimeout, + } + + // ErrUserInfoFailed represents an error when fetching user info fails. + ErrUserInfoFailed = &AuthenticationError{ + Type: "user_info_failed", + Message: "Failed to fetch GitHub user information", + Code: http.StatusBadRequest, + } +) + +// NewAuthenticationError creates a new authentication error with a cause based on a base error. +func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError { + return &AuthenticationError{ + Type: baseErr.Type, + Message: baseErr.Message, + Code: baseErr.Code, + Cause: cause, + } +} + +// IsAuthenticationError checks if an error is an authentication error. +func IsAuthenticationError(err error) bool { + var authenticationError *AuthenticationError + ok := errors.As(err, &authenticationError) + return ok +} + +// IsOAuthError checks if an error is an OAuth error. +func IsOAuthError(err error) bool { + var oAuthError *OAuthError + ok := errors.As(err, &oAuthError) + return ok +} + +// GetUserFriendlyMessage returns a user-friendly error message based on the error type. +func GetUserFriendlyMessage(err error) string { + switch { + case IsAuthenticationError(err): + var authErr *AuthenticationError + errors.As(err, &authErr) + switch authErr.Type { + case "device_code_failed": + return "Failed to start GitHub authentication. Please check your network connection and try again." + case "device_code_expired": + return "The authentication code has expired. Please try again." + case "authorization_pending": + return "Waiting for you to authorize the application on GitHub." + case "slow_down": + return "Please wait a moment before trying again." + case "access_denied": + return "Authentication was cancelled or denied." + case "token_exchange_failed": + return "Failed to complete authentication. Please try again." 
+ case "polling_timeout": + return "Authentication timed out. Please try again." + case "user_info_failed": + return "Failed to get your GitHub account information. Please try again." + default: + return "Authentication failed. Please try again." + } + case IsOAuthError(err): + var oauthErr *OAuthError + errors.As(err, &oauthErr) + switch oauthErr.Code { + case "access_denied": + return "Authentication was cancelled or denied." + case "invalid_request": + return "Invalid authentication request. Please try again." + case "server_error": + return "GitHub server error. Please try again later." + default: + return fmt.Sprintf("Authentication failed: %s", oauthErr.Description) + } + default: + return "An unexpected error occurred. Please try again." + } +} diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go new file mode 100644 index 00000000..3f7877b6 --- /dev/null +++ b/internal/auth/copilot/oauth.go @@ -0,0 +1,259 @@ +package copilot + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // copilotClientID is GitHub's Copilot CLI OAuth client ID. + copilotClientID = "Iv1.b507a08c87ecfe98" + // copilotDeviceCodeURL is the endpoint for requesting device codes. + copilotDeviceCodeURL = "https://github.com/login/device/code" + // copilotTokenURL is the endpoint for exchanging device codes for tokens. + copilotTokenURL = "https://github.com/login/oauth/access_token" + // copilotUserInfoURL is the endpoint for fetching GitHub user information. + copilotUserInfoURL = "https://api.github.com/user" + // defaultPollInterval is the default interval for polling token endpoint. + defaultPollInterval = 5 * time.Second + // maxPollDuration is the maximum time to wait for user authorization. 
+ maxPollDuration = 15 * time.Minute +) + +// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot. +type DeviceFlowClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewDeviceFlowClient creates a new device flow client. +func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &DeviceFlowClient{ + httpClient: client, + cfg: cfg, + } +} + +// RequestDeviceCode initiates the device flow by requesting a device code from GitHub. +func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) { + data := url.Values{} + data.Set("client_id", copilotClientID) + data.Set("scope", "user:email") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot device code: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var deviceCode DeviceCodeResponse + if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + + return &deviceCode, nil +} + +// PollForToken polls the token endpoint until the user authorizes or the device code expires. 
+func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) { + if deviceCode == nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil")) + } + + interval := time.Duration(deviceCode.Interval) * time.Second + if interval < defaultPollInterval { + interval = defaultPollInterval + } + + deadline := time.Now().Add(maxPollDuration) + if deviceCode.ExpiresIn > 0 { + codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second) + if codeDeadline.Before(deadline) { + deadline = codeDeadline + } + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err()) + case <-ticker.C: + if time.Now().After(deadline) { + return nil, ErrPollingTimeout + } + + token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode) + if err != nil { + var authErr *AuthenticationError + if IsAuthenticationError(err) { + if ok := (err.(*AuthenticationError)); ok != nil { + authErr = ok + } + } + if authErr != nil { + switch authErr.Type { + case ErrAuthorizationPending.Type: + // Continue polling + continue + case ErrSlowDown.Type: + // Increase interval and continue + interval += 5 * time.Second + ticker.Reset(interval) + continue + case ErrDeviceCodeExpired.Type: + return nil, err + case ErrAccessDenied.Type: + return nil, err + } + } + return nil, err + } + return token, nil + } + } +} + +// exchangeDeviceCode attempts to exchange the device code for an access token. 
+func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) { + data := url.Values{} + data.Set("client_id", copilotClientID) + data.Set("device_code", deviceCode) + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot token exchange: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + // GitHub returns 200 for both success and error cases in device flow + // Check for OAuth error response first + var oauthResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + } + + if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if oauthResp.Error != "" { + switch oauthResp.Error { + case "authorization_pending": + return nil, ErrAuthorizationPending + case "slow_down": + return nil, ErrSlowDown + case "expired_token": + return nil, ErrDeviceCodeExpired + case "access_denied": + return nil, ErrAccessDenied + default: + return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode) + } + } + + if oauthResp.AccessToken == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, 
fmt.Errorf("empty access token")) + } + + return &CopilotTokenData{ + AccessToken: oauthResp.AccessToken, + TokenType: oauthResp.TokenType, + Scope: oauthResp.Scope, + }, nil +} + +// FetchUserInfo retrieves the GitHub username for the authenticated user. +func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) { + if accessToken == "" { + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty")) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil) + if err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "CLIProxyAPI") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot user info: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var userInfo struct { + Login string `json:"login"` + } + if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + + if userInfo.Login == "" { + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username")) + } + + return userInfo.Login, nil +} diff --git a/internal/auth/copilot/token.go b/internal/auth/copilot/token.go new file mode 100644 index 00000000..4e5eed6c --- /dev/null +++ b/internal/auth/copilot/token.go @@ -0,0 +1,93 @@ +// Package copilot provides authentication and token management functionality +// for GitHub Copilot AI services. 
It handles OAuth2 device flow token storage, +// serialization, and retrieval for maintaining authenticated sessions with the Copilot API. +package copilot + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication. +// It maintains compatibility with the existing auth system while adding Copilot-specific fields +// for managing access tokens and user account information. +type CopilotTokenStorage struct { + // AccessToken is the OAuth2 access token used for authenticating API requests. + AccessToken string `json:"access_token"` + // TokenType is the type of token, typically "bearer". + TokenType string `json:"token_type"` + // Scope is the OAuth2 scope granted to the token. + Scope string `json:"scope"` + // ExpiresAt is the timestamp when the access token expires (if provided). + ExpiresAt string `json:"expires_at,omitempty"` + // Username is the GitHub username associated with this token. + Username string `json:"username"` + // Type indicates the authentication provider type, always "github-copilot" for this storage. + Type string `json:"type"` +} + +// CopilotTokenData holds the raw OAuth token response from GitHub. +type CopilotTokenData struct { + // AccessToken is the OAuth2 access token. + AccessToken string `json:"access_token"` + // TokenType is the type of token, typically "bearer". + TokenType string `json:"token_type"` + // Scope is the OAuth2 scope granted to the token. + Scope string `json:"scope"` +} + +// CopilotAuthBundle bundles authentication data for storage. +type CopilotAuthBundle struct { + // TokenData contains the OAuth token information. + TokenData *CopilotTokenData + // Username is the GitHub username. + Username string +} + +// DeviceCodeResponse represents GitHub's device code response. +type DeviceCodeResponse struct { + // DeviceCode is the device verification code. 
+ DeviceCode string `json:"device_code"` + // UserCode is the code the user must enter at the verification URI. + UserCode string `json:"user_code"` + // VerificationURI is the URL where the user should enter the code. + VerificationURI string `json:"verification_uri"` + // ExpiresIn is the number of seconds until the device code expires. + ExpiresIn int `json:"expires_in"` + // Interval is the minimum number of seconds to wait between polling requests. + Interval int `json:"interval"` +} + +// SaveTokenToFile serializes the Copilot token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path for persistent storage. +// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "github-copilot" + if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index e6caa954..baf43bae 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -6,7 +6,7 @@ import ( // newAuthManager creates a new authentication manager instance with all supported // authenticators and a file-based token store. It initializes authenticators for -// Gemini, Codex, Claude, and Qwen providers. +// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers. 
// // Returns: // - *sdkAuth.Manager: A configured authentication manager instance @@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), + sdkAuth.NewGitHubCopilotAuthenticator(), ) return manager } diff --git a/internal/cmd/github_copilot_login.go b/internal/cmd/github_copilot_login.go new file mode 100644 index 00000000..056e811f --- /dev/null +++ b/internal/cmd/github_copilot_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens. +// It initiates the device flow authentication, displays the user code for the user to enter +// at GitHub's verification URL, and waits for authorization before saving the tokens. +// +// Parameters: +// - cfg: The application configuration containing proxy and auth directory settings +// - options: Login options including browser behavior settings +func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts) + if err != nil { + log.Errorf("GitHub Copilot authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("GitHub Copilot authentication successful!") +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 
51f984f2..42e68239 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -984,3 +984,187 @@ func GetIFlowModels() []*ModelInfo { } return models } + +// GetGitHubCopilotModels returns the available models for GitHub Copilot. +// These models are available through the GitHub Copilot API at api.githubcopilot.com. +func GetGitHubCopilotModels() []*ModelInfo { + now := int64(1732752000) // 2024-11-27 + return []*ModelInfo{ + { + ID: "gpt-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-4.1", + Description: "OpenAI GPT-4.1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "gpt-5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5", + Description: "OpenAI GPT-5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Mini", + Description: "OpenAI GPT-5 Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "gpt-5-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Codex", + Description: "OpenAI GPT-5 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1", + Description: "OpenAI GPT-5.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex", + Description: "OpenAI GPT-5.1 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { 
+ ID: "gpt-5.1-codex-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex Mini", + Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "claude-haiku-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Haiku 4.5", + Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-opus-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.1", + Description: "Anthropic Claude Opus 4.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32000, + }, + { + ID: "claude-opus-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.5", + Description: "Anthropic Claude Opus 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-sonnet-4", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4", + Description: "Anthropic Claude Sonnet 4 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-sonnet-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4.5", + Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "gemini-2.5-pro", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 2.5 Pro", + Description: "Google Gemini 2.5 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + }, + { + ID: "gemini-3-pro", + Object: "model", + Created: now, 
+ OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 3 Pro", + Description: "Google Gemini 3 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + }, + { + ID: "grok-code-fast-1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Grok Code Fast 1", + Description: "xAI Grok Code Fast 1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "raptor-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Raptor Mini", + Description: "Raptor Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + } +} diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go new file mode 100644 index 00000000..0d1240f7 --- /dev/null +++ b/internal/runtime/executor/github_copilot_executor.go @@ -0,0 +1,354 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/google/uuid" + copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" +) + +const ( + githubCopilotBaseURL = "https://api.githubcopilot.com" + githubCopilotChatPath = "/chat/completions" + githubCopilotAuthType = "github-copilot" + githubCopilotTokenCacheTTL = 25 * time.Minute +) + +// GitHubCopilotExecutor handles requests to the GitHub Copilot API. +type GitHubCopilotExecutor struct { + cfg *config.Config +} + +// NewGitHubCopilotExecutor constructs a new executor instance. 
+func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor { + return &GitHubCopilotExecutor{cfg: cfg} +} + +// Identifier implements ProviderExecutor. +func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType } + +// PrepareRequest implements ProviderExecutor. +func (e *GitHubCopilotExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { + return nil +} + +// Execute handles non-streaming requests to GitHub Copilot. +func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + apiToken, errToken := e.ensureAPIToken(ctx, auth) + if errToken != nil { + return resp, errToken + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body = e.normalizeModel(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "stream", false) + + url := githubCopilotBaseURL + githubCopilotChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return resp, err + } + e.applyHeaders(httpReq, apiToken) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + 
recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + }() + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + data, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return resp, err + } + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + + detail := parseOpenAIUsage(data) + if detail.TotalTokens > 0 { + reporter.publish(ctx, detail) + } + + var param any + converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(converted)} + reporter.ensurePublished(ctx) + return resp, nil +} + +// ExecuteStream handles streaming requests to GitHub Copilot. 
+func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + apiToken, errToken := e.ensureAPIToken(ctx, auth) + if errToken != nil { + return nil, errToken + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + body = e.normalizeModel(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "stream", true) + // Enable stream options for usage stats in stream + body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) + + url := githubCopilotBaseURL + githubCopilotChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + e.applyHeaders(httpReq, apiToken) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + data, readErr := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + 
log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + return nil, readErr + } + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return nil, err + } + + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + }() + + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 20_971_520) + var param any + + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + + // Parse SSE data + if bytes.HasPrefix(line, dataTag) { + data := bytes.TrimSpace(line[5:]) + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if detail, ok := parseOpenAIStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + } + + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } else { + reporter.ensurePublished(ctx) + } + }() + + return stream, nil +} + +// CountTokens is not supported for GitHub Copilot. 
+func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"} +} + +// Refresh validates the GitHub token is still working. +// GitHub OAuth tokens don't expire traditionally, so we just validate. +func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + // Get the GitHub access token + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + return auth, nil + } + + // Validate the token can still get a Copilot API token + copilotAuth := copilotauth.NewCopilotAuth(e.cfg) + _, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) + if err != nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)} + } + + return auth, nil +} + +// ensureAPIToken gets or refreshes the Copilot API token. 
+func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) { + if auth == nil { + return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + // Check for cached API token + if cachedToken := metaStringValue(auth.Metadata, "copilot_api_token"); cachedToken != "" { + if expiresAt := tokenExpiry(auth.Metadata); expiresAt.After(time.Now().Add(5 * time.Minute)) { + return cachedToken, nil + } + } + + // Get the GitHub access token + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"} + } + + // Get a new Copilot API token + copilotAuth := copilotauth.NewCopilotAuth(e.cfg) + apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) + if err != nil { + return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)} + } + + // Cache the token in metadata (will be persisted on next save) + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["copilot_api_token"] = apiToken.Token + if apiToken.ExpiresAt > 0 { + auth.Metadata["expired"] = time.Unix(apiToken.ExpiresAt, 0).Format(time.RFC3339) + } else { + auth.Metadata["expired"] = time.Now().Add(githubCopilotTokenCacheTTL).Format(time.RFC3339) + } + + return apiToken.Token, nil +} + +// applyHeaders sets the required headers for GitHub Copilot API requests. 
+func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", "Bearer "+apiToken) + r.Header.Set("Accept", "application/json") + r.Header.Set("User-Agent", "GithubCopilot/1.0") + r.Header.Set("Editor-Version", "vscode/1.100.0") + r.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + r.Header.Set("Openai-Intent", "conversation-panel") + r.Header.Set("Copilot-Integration-Id", "vscode-chat") + r.Header.Set("X-Request-Id", uuid.NewString()) +} + +// normalizeModel ensures the model name is correct for the API. +func (e *GitHubCopilotExecutor) normalizeModel(requestedModel string, body []byte) []byte { + // Map friendly names to API model names + modelMap := map[string]string{ + "gpt-4.1": "gpt-4.1", + "gpt-5": "gpt-5", + "gpt-5-mini": "gpt-5-mini", + "gpt-5-codex": "gpt-5-codex", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "claude-haiku-4.5": "claude-haiku-4.5", + "claude-opus-4.1": "claude-opus-4.1", + "claude-opus-4.5": "claude-opus-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-pro": "gemini-3-pro", + "grok-code-fast-1": "grok-code-fast-1", + "raptor-mini": "raptor-mini", + } + + if mapped, ok := modelMap[requestedModel]; ok { + body, _ = sjson.SetBytes(body, "model", mapped) + } + + return body +} diff --git a/sdk/auth/github_copilot.go b/sdk/auth/github_copilot.go new file mode 100644 index 00000000..3ffcf133 --- /dev/null +++ b/sdk/auth/github_copilot.go @@ -0,0 +1,133 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log 
"github.com/sirupsen/logrus" +) + +// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot. +type GitHubCopilotAuthenticator struct{} + +// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator. +func NewGitHubCopilotAuthenticator() Authenticator { + return &GitHubCopilotAuthenticator{} +} + +// Provider returns the provider key for github-copilot. +func (GitHubCopilotAuthenticator) Provider() string { + return "github-copilot" +} + +// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense. +// The token remains valid until the user revokes it or the Copilot subscription expires. +func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration { + return nil +} + +// Login initiates the GitHub device flow authentication for Copilot access. +func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if ctx == nil { + ctx = context.Background() + } + if opts == nil { + opts = &LoginOptions{} + } + + authSvc := copilot.NewCopilotAuth(cfg) + + // Start the device flow + fmt.Println("Starting GitHub Copilot authentication...") + deviceCode, err := authSvc.StartDeviceFlow(ctx) + if err != nil { + return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err) + } + + // Display the user code and verification URL + fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI) + fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode) + + // Try to open the browser automatically + if !opts.NoBrowser { + if browser.IsAvailable() { + if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + } + + fmt.Println("Waiting for GitHub authorization...") + fmt.Printf("(This will timeout in %d seconds if not 
authorized)\n", deviceCode.ExpiresIn) + + // Wait for user authorization + authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode) + if err != nil { + errMsg := copilot.GetUserFriendlyMessage(err) + return nil, fmt.Errorf("github-copilot: %s", errMsg) + } + + // Verify the token can get a Copilot API token + fmt.Println("Verifying Copilot access...") + apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken) + if err != nil { + return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err) + } + + // Create the token storage + tokenStorage := authSvc.CreateTokenStorage(authBundle) + + // Build metadata + metadata := map[string]any{ + "type": "github-copilot", + "username": authBundle.Username, + "access_token": authBundle.TokenData.AccessToken, + "token_type": authBundle.TokenData.TokenType, + "scope": authBundle.TokenData.Scope, + "timestamp": time.Now().UnixMilli(), + "api_endpoint": copilot.BuildChatCompletionURL(), + } + + if apiToken.ExpiresAt > 0 { + metadata["api_token_expires_at"] = apiToken.ExpiresAt + } + + fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username) + + fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username) + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: authBundle.Username, + Storage: tokenStorage, + Metadata: metadata, + }, nil +} + +// RefreshGitHubCopilotToken validates and returns the current token status. +// GitHub OAuth tokens don't need traditional refresh - we just validate they still work. 
+func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error { + if storage == nil || storage.AccessToken == "" { + return fmt.Errorf("no token available") + } + + authSvc := copilot.NewCopilotAuth(cfg) + + // Validate the token can still get a Copilot API token + _, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken) + if err != nil { + return fmt.Errorf("token validation failed: %w", err) + } + + return nil +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index e82ac684..e8997f71 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -14,6 +14,7 @@ func init() { registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) + registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6e303ed2..8b66a9a9 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -351,6 +351,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg)) + case "github-copilot": + s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -662,6 +664,8 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = registry.GetQwenModels() case "iflow": models = registry.GetIFlowModels() + case "github-copilot": + models = registry.GetGitHubCopilotModels() default: // Handle OpenAI-compatibility 
providers by name using config if s.cfg != nil { From efd28bf981a33f15323f655411ef5405a26b80e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Thu, 27 Nov 2025 22:47:51 +0100 Subject: [PATCH 02/38] refactor(copilot): address PR review feedback - Simplify error type checking in oauth.go using errors.As directly - Remove redundant errors.As call in GetUserFriendlyMessage - Remove unused CachedAPIToken and TokenManager types (dead code) --- internal/auth/copilot/copilot_auth.go | 71 --------------------------- internal/auth/copilot/errors.go | 17 +++---- internal/auth/copilot/oauth.go | 8 +-- 3 files changed, 10 insertions(+), 86 deletions(-) diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go index fbfb1762..e4e5876b 100644 --- a/internal/auth/copilot/copilot_auth.go +++ b/internal/auth/copilot/copilot_auth.go @@ -208,77 +208,6 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url return req, nil } -// CachedAPIToken manages caching of Copilot API tokens. -type CachedAPIToken struct { - Token *CopilotAPIToken - ExpiresAt time.Time -} - -// IsExpired checks if the cached token has expired. -func (c *CachedAPIToken) IsExpired() bool { - if c.Token == nil { - return true - } - // Add a 5-minute buffer before expiration - return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) -} - -// TokenManager handles caching and refreshing of Copilot API tokens. -type TokenManager struct { - auth *CopilotAuth - githubToken string - cachedToken *CachedAPIToken -} - -// NewTokenManager creates a new token manager for handling Copilot API tokens. -func NewTokenManager(auth *CopilotAuth, githubToken string) *TokenManager { - return &TokenManager{ - auth: auth, - githubToken: githubToken, - } -} - -// GetToken returns a valid Copilot API token, refreshing if necessary. 
-func (tm *TokenManager) GetToken(ctx context.Context) (*CopilotAPIToken, error) { - if tm.cachedToken != nil && !tm.cachedToken.IsExpired() { - return tm.cachedToken.Token, nil - } - - // Fetch a new API token - apiToken, err := tm.auth.GetCopilotAPIToken(ctx, tm.githubToken) - if err != nil { - return nil, err - } - - // Cache the token - expiresAt := time.Now().Add(30 * time.Minute) // Default 30 min cache - if apiToken.ExpiresAt > 0 { - expiresAt = time.Unix(apiToken.ExpiresAt, 0) - } - - tm.cachedToken = &CachedAPIToken{ - Token: apiToken, - ExpiresAt: expiresAt, - } - - return apiToken, nil -} - -// GetAuthorizationHeader returns the authorization header value for API requests. -func (tm *TokenManager) GetAuthorizationHeader(ctx context.Context) (string, error) { - token, err := tm.GetToken(ctx) - if err != nil { - return "", err - } - return "Bearer " + token.Token, nil -} - -// UpdateGitHubToken updates the GitHub access token used for getting API tokens. -func (tm *TokenManager) UpdateGitHubToken(githubToken string) { - tm.githubToken = githubToken - tm.cachedToken = nil // Invalidate cache -} - // BuildChatCompletionURL builds the URL for chat completions API. func BuildChatCompletionURL() string { return copilotAPIEndpoint + "/chat/completions" diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go index 01f8e754..dac6ecfa 100644 --- a/internal/auth/copilot/errors.go +++ b/internal/auth/copilot/errors.go @@ -140,10 +140,8 @@ func IsOAuthError(err error) bool { // GetUserFriendlyMessage returns a user-friendly error message based on the error type. func GetUserFriendlyMessage(err error) string { - switch { - case IsAuthenticationError(err): - var authErr *AuthenticationError - errors.As(err, &authErr) + var authErr *AuthenticationError + if errors.As(err, &authErr) { switch authErr.Type { case "device_code_failed": return "Failed to start GitHub authentication. Please check your network connection and try again." 
@@ -164,9 +162,10 @@ func GetUserFriendlyMessage(err error) string { default: return "Authentication failed. Please try again." } - case IsOAuthError(err): - var oauthErr *OAuthError - errors.As(err, &oauthErr) + } + + var oauthErr *OAuthError + if errors.As(err, &oauthErr) { switch oauthErr.Code { case "access_denied": return "Authentication was cancelled or denied." @@ -177,7 +176,7 @@ func GetUserFriendlyMessage(err error) string { default: return fmt.Sprintf("Authentication failed: %s", oauthErr.Description) } - default: - return "An unexpected error occurred. Please try again." } + + return "An unexpected error occurred. Please try again." } diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go index 3f7877b6..1aecf596 100644 --- a/internal/auth/copilot/oauth.go +++ b/internal/auth/copilot/oauth.go @@ -3,6 +3,7 @@ package copilot import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -118,12 +119,7 @@ func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceC token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode) if err != nil { var authErr *AuthenticationError - if IsAuthenticationError(err) { - if ok := (err.(*AuthenticationError)); ok != nil { - authErr = ok - } - } - if authErr != nil { + if errors.As(err, &authErr) { switch authErr.Type { case ErrAuthorizationPending.Type: // Continue polling From 2c296e9cb19ef48bf2b8f51911ab4fd7bf115b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Fri, 28 Nov 2025 08:32:36 +0100 Subject: [PATCH 03/38] refactor(copilot): improve code quality in authentication module - Add Unwrap() to AuthenticationError for proper error chain handling with errors.Is/As - Extract hardcoded header values to constants for maintainability - Replace verbose status code checks with isHTTPSuccess() helper - Remove unused ExtractBearerToken() and BuildModelsURL() functions - Make buildChatCompletionURL() private (only used internally) - Remove unused 
'strings' import --- internal/auth/copilot/copilot_auth.go | 47 ++++++++++++--------------- internal/auth/copilot/errors.go | 5 +++ internal/auth/copilot/oauth.go | 4 +-- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go index e4e5876b..c40e7082 100644 --- a/internal/auth/copilot/copilot_auth.go +++ b/internal/auth/copilot/copilot_auth.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net/http" - "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -21,6 +20,13 @@ const ( copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token" // copilotAPIEndpoint is the base URL for making API requests. copilotAPIEndpoint = "https://api.githubcopilot.com" + + // Common HTTP header values for Copilot API requests. + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" ) // CopilotAPIToken represents the Copilot API token response. 
@@ -102,9 +108,9 @@ func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken req.Header.Set("Authorization", "token "+githubAccessToken) req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "GithubCopilot/1.0") - req.Header.Set("Editor-Version", "vscode/1.100.0") - req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) resp, err := c.httpClient.Do(req) if err != nil { @@ -121,7 +127,7 @@ func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(resp.StatusCode) { return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) } @@ -199,32 +205,21 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url req.Header.Set("Authorization", "Bearer "+apiToken.Token) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "GithubCopilot/1.0") - req.Header.Set("Editor-Version", "vscode/1.100.0") - req.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") - req.Header.Set("Openai-Intent", "conversation-panel") - req.Header.Set("Copilot-Integration-Id", "vscode-chat") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + req.Header.Set("Openai-Intent", copilotOpenAIIntent) + req.Header.Set("Copilot-Integration-Id", copilotIntegrationID) return req, nil } -// BuildChatCompletionURL builds the URL for chat completions API. 
-func BuildChatCompletionURL() string { +// buildChatCompletionURL builds the URL for chat completions API. +func buildChatCompletionURL() string { return copilotAPIEndpoint + "/chat/completions" } -// BuildModelsURL builds the URL for listing available models. -func BuildModelsURL() string { - return copilotAPIEndpoint + "/models" -} - -// ExtractBearerToken extracts the bearer token from an Authorization header. -func ExtractBearerToken(authHeader string) string { - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimPrefix(authHeader, "Bearer ") - } - if strings.HasPrefix(authHeader, "token ") { - return strings.TrimPrefix(authHeader, "token ") - } - return authHeader +// isHTTPSuccess checks if the status code indicates success (2xx). +func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 } diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go index dac6ecfa..a82dd8ec 100644 --- a/internal/auth/copilot/errors.go +++ b/internal/auth/copilot/errors.go @@ -55,6 +55,11 @@ func (e *AuthenticationError) Error() string { return fmt.Sprintf("%s: %s", e.Type, e.Message) } +// Unwrap returns the underlying cause of the error. +func (e *AuthenticationError) Unwrap() error { + return e.Cause +} + // Common authentication error types for GitHub Copilot device flow. var ( // ErrDeviceCodeFailed represents an error when requesting the device code fails. 
diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go index 1aecf596..d3f46aaa 100644 --- a/internal/auth/copilot/oauth.go +++ b/internal/auth/copilot/oauth.go @@ -72,7 +72,7 @@ func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeRe } }() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(resp.StatusCode) { bodyBytes, _ := io.ReadAll(resp.Body) return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) } @@ -235,7 +235,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string } }() - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(resp.StatusCode) { bodyBytes, _ := io.ReadAll(resp.Body) return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) } From 7515090cb686f0a502af24f59dbe53aaecf135ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Mart=C3=ADnez?= Date: Fri, 28 Nov 2025 08:33:51 +0100 Subject: [PATCH 04/38] refactor(executor): improve concurrency and code quality in GitHub Copilot executor - Replace concurrent-unsafe metadata caching with thread-safe sync.RWMutex-protected map - Extract magic numbers and hardcoded header values to named constants - Replace verbose status code checks with isHTTPSuccess() helper - Simplify normalizeModel() to no-op with explanatory comment (models already canonical) - Remove redundant metadata manipulation in token caching - Improve code clarity and performance with proper cache management --- .../executor/github_copilot_executor.go | 109 ++++++++++-------- sdk/auth/github_copilot.go | 6 +- 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index 0d1240f7..b2bc73df 100644 --- 
a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/google/uuid" @@ -24,16 +25,38 @@ const ( githubCopilotChatPath = "/chat/completions" githubCopilotAuthType = "github-copilot" githubCopilotTokenCacheTTL = 25 * time.Minute + // tokenExpiryBuffer is the time before expiry when we should refresh the token. + tokenExpiryBuffer = 5 * time.Minute + // maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB). + maxScannerBufferSize = 20_971_520 + + // Copilot API header values. + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" ) // GitHubCopilotExecutor handles requests to the GitHub Copilot API. type GitHubCopilotExecutor struct { - cfg *config.Config + cfg *config.Config + mu sync.RWMutex + cache map[string]*cachedAPIToken +} + +// cachedAPIToken stores a cached Copilot API token with its expiry. +type cachedAPIToken struct { + token string + expiresAt time.Time } // NewGitHubCopilotExecutor constructs a new executor instance. func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor { - return &GitHubCopilotExecutor{cfg: cfg} + return &GitHubCopilotExecutor{ + cfg: cfg, + cache: make(map[string]*cachedAPIToken), + } } // Identifier implements ProviderExecutor. @@ -100,7 +123,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. 
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(httpResp.StatusCode) { data, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, data) log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) @@ -180,7 +203,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + if !isHTTPSuccess(httpResp.StatusCode) { data, readErr := io.ReadAll(httpResp.Body) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("github-copilot executor: close response body error: %v", errClose) @@ -207,7 +230,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox }() scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 20_971_520) + scanner.Buffer(nil, maxScannerBufferSize) var param any for scanner.Scan() { @@ -277,19 +300,20 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"} } - // Check for cached API token - if cachedToken := metaStringValue(auth.Metadata, "copilot_api_token"); cachedToken != "" { - if expiresAt := tokenExpiry(auth.Metadata); expiresAt.After(time.Now().Add(5 * time.Minute)) { - return cachedToken, nil - } - } - // Get the GitHub access token accessToken := metaStringValue(auth.Metadata, "access_token") if accessToken == "" { return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"} } + // Check for cached API token using thread-safe access + e.mu.RLock() + if cached, ok := e.cache[accessToken]; ok && 
cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) { + e.mu.RUnlock() + return cached.token, nil + } + e.mu.RUnlock() + // Get a new Copilot API token copilotAuth := copilotauth.NewCopilotAuth(e.cfg) apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) @@ -297,16 +321,17 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)} } - // Cache the token in metadata (will be persisted on next save) - if auth.Metadata == nil { - auth.Metadata = make(map[string]any) - } - auth.Metadata["copilot_api_token"] = apiToken.Token + // Cache the token with thread-safe access + expiresAt := time.Now().Add(githubCopilotTokenCacheTTL) if apiToken.ExpiresAt > 0 { - auth.Metadata["expired"] = time.Unix(apiToken.ExpiresAt, 0).Format(time.RFC3339) - } else { - auth.Metadata["expired"] = time.Now().Add(githubCopilotTokenCacheTTL).Format(time.RFC3339) + expiresAt = time.Unix(apiToken.ExpiresAt, 0) } + e.mu.Lock() + e.cache[accessToken] = &cachedAPIToken{ + token: apiToken.Token, + expiresAt: expiresAt, + } + e.mu.Unlock() return apiToken.Token, nil } @@ -316,39 +341,21 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) { r.Header.Set("Content-Type", "application/json") r.Header.Set("Authorization", "Bearer "+apiToken) r.Header.Set("Accept", "application/json") - r.Header.Set("User-Agent", "GithubCopilot/1.0") - r.Header.Set("Editor-Version", "vscode/1.100.0") - r.Header.Set("Editor-Plugin-Version", "copilot/1.300.0") - r.Header.Set("Openai-Intent", "conversation-panel") - r.Header.Set("Copilot-Integration-Id", "vscode-chat") + r.Header.Set("User-Agent", copilotUserAgent) + r.Header.Set("Editor-Version", copilotEditorVersion) + r.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + r.Header.Set("Openai-Intent", copilotOpenAIIntent) + r.Header.Set("Copilot-Integration-Id", copilotIntegrationID) 
r.Header.Set("X-Request-Id", uuid.NewString()) } -// normalizeModel ensures the model name is correct for the API. -func (e *GitHubCopilotExecutor) normalizeModel(requestedModel string, body []byte) []byte { - // Map friendly names to API model names - modelMap := map[string]string{ - "gpt-4.1": "gpt-4.1", - "gpt-5": "gpt-5", - "gpt-5-mini": "gpt-5-mini", - "gpt-5-codex": "gpt-5-codex", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "claude-haiku-4.5": "claude-haiku-4.5", - "claude-opus-4.1": "claude-opus-4.1", - "claude-opus-4.5": "claude-opus-4.5", - "claude-sonnet-4": "claude-sonnet-4", - "claude-sonnet-4.5": "claude-sonnet-4.5", - "gemini-2.5-pro": "gemini-2.5-pro", - "gemini-3-pro": "gemini-3-pro", - "grok-code-fast-1": "grok-code-fast-1", - "raptor-mini": "raptor-mini", - } - - if mapped, ok := modelMap[requestedModel]; ok { - body, _ = sjson.SetBytes(body, "model", mapped) - } - +// normalizeModel is a no-op as GitHub Copilot accepts model names directly. +// Model mapping should be done at the registry level if needed. +func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte { return body } + +// isHTTPSuccess checks if the status code indicates success (2xx). 
+func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 +} diff --git a/sdk/auth/github_copilot.go b/sdk/auth/github_copilot.go index 3ffcf133..1d14ac47 100644 --- a/sdk/auth/github_copilot.go +++ b/sdk/auth/github_copilot.go @@ -36,9 +36,6 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi if cfg == nil { return nil, fmt.Errorf("cliproxy auth: configuration is required") } - if ctx == nil { - ctx = context.Background() - } if opts == nil { opts = &LoginOptions{} } @@ -85,7 +82,7 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi // Create the token storage tokenStorage := authSvc.CreateTokenStorage(authBundle) - // Build metadata + // Build metadata with token information for the executor metadata := map[string]any{ "type": "github-copilot", "username": authBundle.Username, @@ -93,7 +90,6 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi "token_type": authBundle.TokenData.TokenType, "scope": authBundle.TokenData.Scope, "timestamp": time.Now().UnixMilli(), - "api_endpoint": copilot.BuildChatCompletionURL(), } if apiToken.ExpiresAt > 0 { From 0087eecad8516e4021bdeeff5765acd5a0ac6837 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 30 Nov 2025 17:23:47 +0800 Subject: [PATCH 05/38] **fix(build): append '-plus' suffix to version metadata in workflows and Goreleaser** - Updated release and Docker workflows to ensure the `-plus` suffix is added to build versions when missing. - Adjusted Goreleaser configuration to include `-plus` suffix in `main.Version` during build process. 
--- .github/workflows/docker-image.yml | 7 +++++-- .github/workflows/release.yaml | 6 +++++- .goreleaser.yml | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 3aacf4f5..63e1c4b7 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -26,7 +26,11 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Generate Build Metadata run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV + VERSION=$(git describe --tags --always --dirty) + if [[ "$VERSION" != *"-plus" ]]; then + VERSION="${VERSION}-plus" + fi + echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - name: Build and push @@ -42,5 +46,4 @@ jobs: COMMIT=${{ env.COMMIT }} BUILD_DATE=${{ env.BUILD_DATE }} tags: | - ${{ env.DOCKERHUB_REPO }}:latest ${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4bb5e63b..53c1cf32 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -23,7 +23,11 @@ jobs: cache: true - name: Generate Build Metadata run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV + VERSION=$(git describe --tags --always --dirty) + if [[ "$VERSION" != *"-plus" ]]; then + VERSION="${VERSION}-plus" + fi + echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - uses: goreleaser/goreleaser-action@v4 diff --git a/.goreleaser.yml b/.goreleaser.yml index 31d05e6d..8fbec015 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -12,7 +12,7 @@ builds: main: ./cmd/server/ binary: cli-proxy-api ldflags: - - -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' + - -s -w -X 
'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' archives: - id: "cli-proxy-api" format: tar.gz From 52e5551d8f2298274338c467c6629ad4978023de Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 30 Nov 2025 17:54:37 +0800 Subject: [PATCH 06/38] **chore(build): rename artifacts and adjust workflows for 'plus' variant** - Updated Docker and release workflows to use `cli-proxy-api-plus` as the Docker repository name and adjusted tag generation logic. - Renamed artifacts in Goreleaser configuration to align with the new 'plus' variant naming convention. --- .github/workflows/docker-image.yml | 9 +++------ .github/workflows/release.yaml | 3 --- .goreleaser.yml | 6 +++--- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 63e1c4b7..de924672 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -7,7 +7,7 @@ on: env: APP_NAME: CLIProxyAPI - DOCKERHUB_REPO: eceasy/cli-proxy-api + DOCKERHUB_REPO: eceasy/cli-proxy-api-plus jobs: docker: @@ -26,11 +26,7 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Generate Build Metadata run: | - VERSION=$(git describe --tags --always --dirty) - if [[ "$VERSION" != *"-plus" ]]; then - VERSION="${VERSION}-plus" - fi - echo "VERSION=${VERSION}" >> $GITHUB_ENV + echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - name: Build and push @@ -46,4 +42,5 @@ jobs: COMMIT=${{ env.COMMIT }} BUILD_DATE=${{ env.BUILD_DATE }} tags: | + ${{ env.DOCKERHUB_REPO }}:latest ${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 53c1cf32..4c4aafe7 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -24,9 +24,6 @@ jobs: - name: Generate Build 
Metadata run: | VERSION=$(git describe --tags --always --dirty) - if [[ "$VERSION" != *"-plus" ]]; then - VERSION="${VERSION}-plus" - fi echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV diff --git a/.goreleaser.yml b/.goreleaser.yml index 8fbec015..6e1829ed 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -1,5 +1,5 @@ builds: - - id: "cli-proxy-api" + - id: "cli-proxy-api-plus" env: - CGO_ENABLED=0 goos: @@ -10,11 +10,11 @@ builds: - amd64 - arm64 main: ./cmd/server/ - binary: cli-proxy-api + binary: cli-proxy-api-plus ldflags: - -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' archives: - - id: "cli-proxy-api" + - id: "cli-proxy-api-plus" format: tar.gz format_overrides: - goos: windows From 8203bf64ec554e1fc93ceeca870e51d45dde08b1 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 30 Nov 2025 20:16:51 +0800 Subject: [PATCH 07/38] **docs: update README for CLIProxyAPI Plus and rename artifacts** - Updated README files to reflect the new 'Plus' variant with detailed third-party provider support information. - Adjusted Dockerfile, `docker-compose.yml`, and build configurations to align with the `CLIProxyAPIPlus` naming convention. - Added information about GitHub Copilot OAuth integration contributed by the community. 
--- Dockerfile | 6 +-- README.md | 93 ++++------------------------------------- README_CN.md | 100 ++++----------------------------------------- docker-compose.yml | 4 +- 4 files changed, 23 insertions(+), 180 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8623dc5e..98509423 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ ARG VERSION=dev ARG COMMIT=none ARG BUILD_DATE=unknown -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/ +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}-plus' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPIPlus ./cmd/server/ FROM alpine:3.22.0 @@ -20,7 +20,7 @@ RUN apk add --no-cache tzdata RUN mkdir /CLIProxyAPI -COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI +COPY --from=builder ./app/CLIProxyAPIPlus /CLIProxyAPI/CLIProxyAPIPlus COPY config.example.yaml /CLIProxyAPI/config.example.yaml @@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone -CMD ["./CLIProxyAPI"] \ No newline at end of file +CMD ["./CLIProxyAPIPlus"] \ No newline at end of file diff --git a/README.md b/README.md index 90d5d465..8e27ab05 100644 --- a/README.md +++ b/README.md @@ -1,97 +1,22 @@ -# CLI Proxy API +# CLIProxyAPI Plus -English | [中文](README_CN.md) +English | [Chinese](README_CN.md) -A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI. +This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project. -It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth. +All third-party provider support is maintained by community contributors; CLIProxyAPI does not provide technical support. 
Please contact the corresponding community maintainer if you need assistance. -So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs. +The Plus release stays in lockstep with the mainline features. -## Sponsor +## Differences from the Mainline -[![z.ai](https://assets.router-for.me/english.png)](https://z.ai/subscribe?ic=8JVLJQFSKB) - -This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN. - -GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences. - -Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB - -## Overview - -- OpenAI/Gemini/Claude compatible API endpoints for CLI models -- OpenAI Codex support (GPT models) via OAuth login -- Claude Code support via OAuth login -- Qwen Code support via OAuth login -- iFlow support via OAuth login -- Amp CLI and IDE extensions support with provider routing -- Streaming and non-streaming responses -- Function calling/tools support -- Multimodal input support (text and images) -- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow) -- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow) -- Generative Language API Key support -- AI Studio Build multi-account load balancing -- Gemini CLI multi-account load balancing -- Claude Code multi-account load balancing -- Qwen Code multi-account load balancing -- iFlow multi-account load balancing -- OpenAI Codex multi-account load balancing -- OpenAI-compatible upstream providers via config (e.g., OpenRouter) -- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`) - -## Getting Started - -CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/) - -## Management API - 
-see [MANAGEMENT_API.md](https://help.router-for.me/management/api) - -## Amp CLI Support - -CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools: - -- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`) -- Management proxy for OAuth authentication and account features -- Smart model fallback with automatic routing -- Security-first design with localhost-only management endpoints - -**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)** - -## SDK Docs - -- Usage: [docs/sdk-usage.md](docs/sdk-usage.md) -- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md) -- Access: [docs/sdk-access.md](docs/sdk-access.md) -- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md) -- Custom Provider Example: `examples/custom-provider` +- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) ## Contributing -Contributions are welcome! Please feel free to submit a Pull Request. +This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected. -1. Fork the repository -2. Create your feature branch (`git checkout -b feature/amazing-feature`) -3. Commit your changes (`git commit -m 'Add some amazing feature'`) -4. Push to the branch (`git push origin feature/amazing-feature`) -5. Open a Pull Request - -## Who is with us? 
- -Those projects are based on CLIProxyAPI: - -### [vibeproxy](https://github.com/automazeio/vibeproxy) - -Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed - -### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator) - -Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed - -> [!NOTE] -> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. +If you need to submit any non-third-party provider changes, please open them against the mainline repository. ## License diff --git a/README_CN.md b/README_CN.md index be0aa234..84e90d40 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,105 +1,23 @@ -# CLI 代理 API +# CLIProxyAPI Plus [English](README.md) | 中文 -一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。 +这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。 -现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。 +所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。 -您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。 +该 Plus 版本的主线功能与主线功能强制同步。 -## 赞助商 +## 与主线版本版本差异 -[![bigmodel.cn](https://assets.router-for.me/chinese.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII) - -本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。 - -GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。 - -智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII - -## 功能特性 - -- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点 -- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录) -- 新增 Claude Code 支持(OAuth 登录) -- 新增 Qwen Code 支持(OAuth 登录) -- 新增 iFlow 支持(OAuth 登录) -- 支持流式与非流式响应 -- 函数调用/工具支持 -- 多模态输入(文本、图片) -- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow) -- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow) -- 支持 Gemini 
AIStudio API 密钥 -- 支持 AI Studio Build 多账户轮询 -- 支持 Gemini CLI 多账户轮询 -- 支持 Claude Code 多账户轮询 -- 支持 Qwen Code 多账户轮询 -- 支持 iFlow 多账户轮询 -- 支持 OpenAI Codex 多账户轮询 -- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter) -- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`) - -## 新手入门 - -CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/) - -## 管理 API 文档 - -请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api) - -## Amp CLI 支持 - -CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具: - -- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`) -- 管理代理,处理 OAuth 认证和账号功能 -- 智能模型回退与自动路由 -- 以安全为先的设计,管理端点仅限 localhost - -**→ [Amp CLI 完整集成指南](docs/amp-cli-integration_CN.md)** - -## SDK 文档 - -- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md) -- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md) -- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md) -- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md) -- 自定义 Provider 示例:`examples/custom-provider` +- 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 ## 贡献 -欢迎贡献!请随时提交 Pull Request。 +该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。 -1. Fork 仓库 -2. 创建您的功能分支(`git checkout -b feature/amazing-feature`) -3. 提交您的更改(`git commit -m 'Add some amazing feature'`) -4. 推送到分支(`git push origin feature/amazing-feature`) -5. 打开 Pull Request - -## 谁与我们在一起? 
- -这些项目基于 CLIProxyAPI: - -### [vibeproxy](https://github.com/automazeio/vibeproxy) - -一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。 - -### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator) - -一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。 - -> [!NOTE] -> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。 +如果需要提交任何非第三方供应商支持的 Pull Request,请提交到主线版本。 ## 许可证 -此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 - -## 写给所有中国网友的 - -QQ 群:188637136 - -或 - -Telegram 群:https://t.me/CLIProxyAPI +此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 29712419..80693fbe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ services: cli-proxy-api: - image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest} + image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api-plus:latest} pull_policy: always build: context: . @@ -9,7 +9,7 @@ services: VERSION: ${VERSION:-dev} COMMIT: ${COMMIT:-none} BUILD_DATE: ${BUILD_DATE:-unknown} - container_name: cli-proxy-api + container_name: cli-proxy-api-plus # env_file: # - .env environment: From 691cdb6bdfae92c9dd1d4207bcbb36360ffbc55d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 5 Dec 2025 10:32:28 +0800 Subject: [PATCH 08/38] **fix(api): update GitHub release URL and user agent for CLIProxyAPIPlus** --- internal/api/handlers/management/config_basic.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index ae292982..f9069198 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -19,8 +19,8 @@ import ( ) const ( - latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest" - latestReleaseUserAgent = "CLIProxyAPI" + latestReleaseURL = 
"https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest" + latestReleaseUserAgent = "CLIProxyAPIPlus" ) func (h *Handler) GetConfig(c *gin.Context) { From 02d8a1cfecb2bf4ee1c5105507f4022f09e0bcde Mon Sep 17 00:00:00 2001 From: Mansi Date: Fri, 5 Dec 2025 22:46:24 +0300 Subject: [PATCH 09/38] feat(kiro): add AWS Builder ID authentication support - Add --kiro-aws-login flag for AWS Builder ID device code flow - Add DoKiroAWSLogin function for AWS SSO OIDC authentication - Complete Kiro integration with AWS, Google OAuth, and social auth - Add kiro executor, translator, and SDK components - Update browser support for Kiro authentication flows --- .gitignore | 2 + cmd/server/main.go | 50 + config.example.yaml | 15 + go.mod | 3 +- go.sum | 5 +- .../api/handlers/management/auth_files.go | 155 +- internal/api/server.go | 2 +- internal/auth/claude/oauth_server.go | 11 + internal/auth/codex/oauth_server.go | 11 + internal/auth/iflow/iflow_auth.go | 22 +- internal/auth/kiro/aws.go | 301 +++ internal/auth/kiro/aws_auth.go | 314 +++ internal/auth/kiro/aws_test.go | 161 ++ internal/auth/kiro/oauth.go | 296 +++ internal/auth/kiro/protocol_handler.go | 725 +++++ internal/auth/kiro/social_auth.go | 403 +++ internal/auth/kiro/sso_oidc.go | 527 ++++ internal/auth/kiro/token.go | 72 + internal/browser/browser.go | 444 +++- internal/cmd/auth_manager.go | 1 + internal/cmd/kiro_login.go | 160 ++ internal/config/config.go | 53 + internal/constant/constant.go | 3 + internal/logging/global_logger.go | 14 +- internal/registry/model_definitions.go | 158 ++ .../runtime/executor/antigravity_executor.go | 10 +- internal/runtime/executor/cache_helpers.go | 32 +- internal/runtime/executor/codex_executor.go | 4 +- internal/runtime/executor/kiro_executor.go | 2353 +++++++++++++++++ internal/runtime/executor/proxy_helpers.go | 49 +- .../claude/gemini/claude_gemini_response.go | 5 +- .../claude_openai_response.go | 4 + .../claude/gemini-cli_claude_response.go | 97 +- 
internal/translator/init.go | 3 + internal/translator/kiro/claude/init.go | 19 + .../translator/kiro/claude/kiro_claude.go | 24 + .../kiro/openai/chat-completions/init.go | 19 + .../chat-completions/kiro_openai_request.go | 258 ++ .../chat-completions/kiro_openai_response.go | 316 +++ internal/watcher/watcher.go | 181 +- sdk/api/handlers/handlers.go | 32 +- sdk/auth/kiro.go | 357 +++ sdk/auth/manager.go | 13 + sdk/auth/refresh_registry.go | 1 + sdk/cliproxy/service.go | 5 + 45 files changed, 7519 insertions(+), 171 deletions(-) create mode 100644 internal/auth/kiro/aws.go create mode 100644 internal/auth/kiro/aws_auth.go create mode 100644 internal/auth/kiro/aws_test.go create mode 100644 internal/auth/kiro/oauth.go create mode 100644 internal/auth/kiro/protocol_handler.go create mode 100644 internal/auth/kiro/social_auth.go create mode 100644 internal/auth/kiro/sso_oidc.go create mode 100644 internal/auth/kiro/token.go create mode 100644 internal/cmd/kiro_login.go create mode 100644 internal/runtime/executor/kiro_executor.go create mode 100644 internal/translator/kiro/claude/init.go create mode 100644 internal/translator/kiro/claude/kiro_claude.go create mode 100644 internal/translator/kiro/openai/chat-completions/init.go create mode 100644 internal/translator/kiro/openai/chat-completions/kiro_openai_request.go create mode 100644 internal/translator/kiro/openai/chat-completions/kiro_openai_response.go create mode 100644 sdk/auth/kiro.go diff --git a/.gitignore b/.gitignore index 9e730c98..9014b273 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Binaries cli-proxy-api +cliproxy *.exe # Configuration @@ -31,6 +32,7 @@ GEMINI.md .vscode/* .claude/* .serena/* +.mcp/cache/ # macOS .DS_Store diff --git a/cmd/server/main.go b/cmd/server/main.go index bbf500e7..28c3e514 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -62,10 +62,16 @@ func main() { var iflowCookie bool var noBrowser bool var antigravityLogin bool + var kiroLogin bool + var 
kiroGoogleLogin bool + var kiroAWSLogin bool + var kiroImport bool var projectID string var vertexImport string var configPath string var password string + var noIncognito bool + var useIncognito bool // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") @@ -75,7 +81,13 @@ func main() { flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") + flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)") + flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") + flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth") + flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)") + flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)") + flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -448,6 +460,44 @@ func main() { cmd.DoIFlowLogin(cfg, options) } else if iflowCookie { cmd.DoIFlowCookieAuth(cfg, options) + } else if kiroLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + // Note: This config mutation is safe - auth commands exit 
after completion + // and don't share config with StartService (which is in the else branch) + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } + cmd.DoKiroLogin(cfg, options) + } else if kiroGoogleLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + // Note: This config mutation is safe - auth commands exit after completion + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } + cmd.DoKiroGoogleLogin(cfg, options) + } else if kiroAWSLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } + cmd.DoKiroAWSLogin(cfg, options) + } else if kiroImport { + cmd.DoKiroImport(cfg, options) } else { // In cloud deploy mode without config file, just wait for shutdown signals if isCloudDeploy && !configFileExists { diff --git a/config.example.yaml b/config.example.yaml index 61f51d47..8fb01c14 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -32,6 +32,11 @@ api-keys: # Enable debug logging debug: false +# Open OAuth URLs in incognito/private browser mode. +# Useful when you want to login with a different account without logging out from your current session. +# Default: false (but Kiro auth defaults to true for multi-account support) +incognito-browser: true + # When true, write application logs to rotating files instead of stdout logging-to-file: false @@ -99,6 +104,16 @@ ws-auth: false # - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking) # - "*haiku*" # wildcard matching substring (e.g. 
claude-3-5-haiku-20241022) +# Kiro (AWS CodeWhisperer) configuration +# Note: Kiro API currently only operates in us-east-1 region +#kiro: +# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file +# agent-task-type: "" # optional: "vibe" or empty (API default) +# - access-token: "aoaAAAAA..." # or provide tokens directly +# refresh-token: "aorAAAAA..." +# profile-arn: "arn:aws:codewhisperer:us-east-1:..." +# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override + # OpenAI compatibility providers # openai-compatibility: # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places. diff --git a/go.mod b/go.mod index c7660c96..85d816c9 100644 --- a/go.mod +++ b/go.mod @@ -13,14 +13,15 @@ require ( github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.4 github.com/minio/minio-go/v7 v7.0.66 + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/sirupsen/logrus v1.9.3 - github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/tiktoken-go/tokenizer v0.7.0 golang.org/x/crypto v0.43.0 golang.org/x/net v0.46.0 golang.org/x/oauth2 v0.30.0 + golang.org/x/term v0.36.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 3acf8562..833c45b9 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod 
h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -126,8 +128,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -169,6 +169,7 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 6f77fda9..161f7a93 100644 --- 
a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -36,9 +36,32 @@ import ( ) var ( - oauthStatus = make(map[string]string) + oauthStatus = make(map[string]string) + oauthStatusMutex sync.RWMutex ) +// getOAuthStatus safely retrieves an OAuth status +func getOAuthStatus(key string) (string, bool) { + oauthStatusMutex.RLock() + defer oauthStatusMutex.RUnlock() + status, ok := oauthStatus[key] + return status, ok +} + +// setOAuthStatus safely sets an OAuth status +func setOAuthStatus(key string, status string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + oauthStatus[key] = status +} + +// deleteOAuthStatus safely deletes an OAuth status +func deleteOAuthStatus(key string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + delete(oauthStatus, key) +} + var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -760,7 +783,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { - oauthStatus[state] = "Timeout waiting for OAuth callback" + setOAuthStatus(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") } data, errRead := os.ReadFile(path) @@ -785,13 +808,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad request" + setOAuthStatus(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "State code error" + setOAuthStatus(state, "State code error") return } @@ -824,7 +847,7 @@ func (h 
*Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + setOAuthStatus(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -835,7 +858,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -848,7 +871,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -873,7 +896,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -882,10 +905,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -944,7 +967,7 @@ func (h *Handler) RequestGeminiCLIToken(c 
*gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + setOAuthStatus(state, "OAuth flow timed out") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -953,13 +976,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") return } authCode = m["code"] if authCode == "" { log.Errorf("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + setOAuthStatus(state, "Authentication failed: code not found") return } break @@ -971,7 +994,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - oauthStatus[state] = "Failed to exchange token" + setOAuthStatus(state, "Failed to exchange token") return } @@ -982,7 +1005,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - oauthStatus[state] = "Could not get user info" + setOAuthStatus(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -991,7 +1014,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - oauthStatus[state] = "Failed to execute request" + setOAuthStatus(state, "Failed to execute request") return } defer func() { @@ -1003,7 +1026,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 
300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1012,7 +1035,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") - oauthStatus[state] = "Failed to get user email from token" + setOAuthStatus(state, "Failed to get user email from token") } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1020,7 +1043,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - oauthStatus[state] = "Failed to unmarshal token" + setOAuthStatus(state, "Failed to unmarshal token") return } @@ -1046,7 +1069,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { log.Fatalf("failed to get authenticated client: %v", errGetClient) - oauthStatus[state] = "Failed to get authenticated client" + setOAuthStatus(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1056,12 +1079,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API 
status: %v", errVerify) - oauthStatus[state] = "Failed to verify Cloud AI API status" + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1069,26 +1092,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - oauthStatus[state] = "Failed to resolve project ID" + setOAuthStatus(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - oauthStatus[state] = "Failed to verify Cloud AI API status" + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - oauthStatus[state] = "Cloud AI API not enabled" + setOAuthStatus(state, "Cloud AI API not enabled") return } } @@ -1111,15 +1134,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + setOAuthStatus(state, "Failed to save token to file") return } - delete(oauthStatus, state) + deleteOAuthStatus(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1180,7 +1203,7 @@ func (h *Handler) 
RequestCodexToken(c *gin.Context) { if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "Timeout waiting for OAuth callback" + setOAuthStatus(state, "Timeout waiting for OAuth callback") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1190,12 +1213,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if errStr := m["error"]; errStr != "" { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(codex.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad Request" + setOAuthStatus(state, "Bad Request") return } if m["state"] != state { authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - oauthStatus[state] = "State code error" + setOAuthStatus(state, "State code error") log.Error(codex.GetUserFriendlyMessage(authErr)) return } @@ -1226,14 +1249,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + setOAuthStatus(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1244,7 +1267,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := 
json.Unmarshal(respBody, &tokenResp); errU != nil { - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1282,7 +1305,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") log.Fatalf("Failed to save authentication tokens: %v", errSave) return } @@ -1291,10 +1314,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1360,7 +1383,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + setOAuthStatus(state, "OAuth flow timed out") return } if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { @@ -1369,18 +1392,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := strings.TrimSpace(payload["error"]); errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") return } if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { log.Errorf("Authentication failed: state mismatch") - oauthStatus[state] = "Authentication failed: state mismatch" + setOAuthStatus(state, "Authentication failed: state mismatch") return } authCode = strings.TrimSpace(payload["code"]) if authCode == "" { log.Error("Authentication failed: code not 
found") - oauthStatus[state] = "Authentication failed: code not found" + setOAuthStatus(state, "Authentication failed: code not found") return } break @@ -1399,7 +1422,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - oauthStatus[state] = "Failed to build token request" + setOAuthStatus(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1407,7 +1430,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - oauthStatus[state] = "Failed to exchange token" + setOAuthStatus(state, "Failed to exchange token") return } defer func() { @@ -1419,7 +1442,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1431,7 +1454,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") return } @@ -1440,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, 
"https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - oauthStatus[state] = "Failed to build user info request" + setOAuthStatus(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1448,7 +1471,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info request: %v", errInfo) - oauthStatus[state] = "Failed to execute user info request" + setOAuthStatus(state, "Failed to execute user info request") return } defer func() { @@ -1467,7 +1490,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1515,11 +1538,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + setOAuthStatus(state, "Failed to save token to file") return } - delete(oauthStatus, state) + deleteOAuthStatus(state) fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1527,7 +1550,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1552,7 +1575,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { fmt.Println("Waiting for authentication...") tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) if errPollForToken != nil { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1571,16 +1594,16 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Fatalf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1619,7 +1642,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { var resultMap map[string]string for { if time.Now().After(deadline) { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return } @@ -1632,26 +1655,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1673,7 +1696,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") log.Fatalf("Failed to save authentication tokens: %v", errSave) return } @@ -1683,10 +1706,10 @@ func (h 
*Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2110,7 +2133,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") - if err, ok := oauthStatus[state]; ok { + if err, ok := getOAuthStatus(state); ok { if err != "" { c.JSON(200, gin.H{"status": "error", "error": err}) } else { @@ -2120,5 +2143,5 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } else { c.JSON(200, gin.H{"status": "ok"}) } - delete(oauthStatus, state) + deleteOAuthStatus(state) } diff --git a/internal/api/server.go b/internal/api/server.go index 9e1c5848..72cb0313 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -902,7 +902,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { for _, p := range cfg.OpenAICompatibility { providerNames = append(providerNames, p.Name) } - s.handlers.OpenAICompatProviders = providerNames + s.handlers.SetOpenAICompatProviders(providerNames) s.handlers.UpdateClients(&cfg.SDKConfig) diff --git a/internal/auth/claude/oauth_server.go b/internal/auth/claude/oauth_server.go index a6ebe2f7..49b04794 100644 --- a/internal/auth/claude/oauth_server.go +++ b/internal/auth/claude/oauth_server.go @@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://console.anthropic.com/" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://console.anthropic.com/" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -251,6 +256,12 @@ func (s *OAuthServer) 
handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. diff --git a/internal/auth/codex/oauth_server.go b/internal/auth/codex/oauth_server.go index 9c6a6c5b..58b5394e 100644 --- a/internal/auth/codex/oauth_server.go +++ b/internal/auth/codex/oauth_server.go @@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://platform.openai.com" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://platform.openai.com" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. 
diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index 4957f519..b3431f84 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "os" "strings" "time" @@ -28,10 +29,21 @@ const ( iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey" // Client credentials provided by iFlow for the Code Assist integration. - iFlowOAuthClientID = "10009311001" - iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" + iFlowOAuthClientID = "10009311001" + // Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var) + defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" ) +// getIFlowClientSecret returns the iFlow OAuth client secret. +// It first checks the IFLOW_CLIENT_SECRET environment variable, +// falling back to the default value if not set. +func getIFlowClientSecret() string { + if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" { + return secret + } + return defaultIFlowClientSecret +} + // DefaultAPIBaseURL is the canonical chat completions endpoint. 
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1" @@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR form.Set("code", code) form.Set("redirect_uri", redirectURI) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) + form.Set("client_secret", getIFlowClientSecret()) req, err := ia.newTokenRequest(ctx, form) if err != nil { @@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I form.Set("grant_type", "refresh_token") form.Set("refresh_token", refreshToken) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) + form.Set("client_secret", getIFlowClientSecret()) req, err := ia.newTokenRequest(ctx, form) if err != nil { @@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt return nil, fmt.Errorf("iflow token: create request failed: %w", err) } - basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret)) + basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Basic "+basic) diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go new file mode 100644 index 00000000..9be025c2 --- /dev/null +++ b/internal/auth/kiro/aws.go @@ -0,0 +1,301 @@ +// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. +// It includes interfaces and implementations for token storage and authentication methods. 
+package kiro + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" +) + +// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow +type PKCECodes struct { + // CodeVerifier is the cryptographically random string used to correlate + // the authorization request to the token request + CodeVerifier string `json:"code_verifier"` + // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded + CodeChallenge string `json:"code_challenge"` +} + +// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro) +type KiroTokenData struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"accessToken"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refreshToken"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profileArn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expiresAt"` + // AuthMethod indicates the authentication method used (e.g., "builder-id", "social") + AuthMethod string `json:"authMethod"` + // Provider indicates the OAuth provider (e.g., "AWS", "Google") + Provider string `json:"provider"` + // ClientID is the OIDC client ID (needed for token refresh) + ClientID string `json:"clientId,omitempty"` + // ClientSecret is the OIDC client secret (needed for token refresh) + ClientSecret string `json:"clientSecret,omitempty"` + // Email is the user's email address (used for file naming) + Email string `json:"email,omitempty"` +} + +// KiroAuthBundle aggregates authentication data after OAuth flow completion +type KiroAuthBundle struct { + // TokenData contains the OAuth tokens from the authentication flow + TokenData KiroTokenData `json:"token_data"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// KiroUsageInfo represents usage information from CodeWhisperer API +type KiroUsageInfo struct 
{ + // SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE") + SubscriptionTitle string `json:"subscription_title"` + // CurrentUsage is the current credit usage + CurrentUsage float64 `json:"current_usage"` + // UsageLimit is the maximum credit limit + UsageLimit float64 `json:"usage_limit"` + // NextReset is the timestamp of the next usage reset + NextReset string `json:"next_reset"` +} + +// KiroModel represents a model available through the CodeWhisperer API +type KiroModel struct { + // ModelID is the unique identifier for the model + ModelID string `json:"modelId"` + // ModelName is the human-readable name + ModelName string `json:"modelName"` + // Description is the model description + Description string `json:"description"` + // RateMultiplier is the credit multiplier for this model + RateMultiplier float64 `json:"rateMultiplier"` + // RateUnit is the unit for rate calculation (e.g., "credit") + RateUnit string `json:"rateUnit"` + // MaxInputTokens is the maximum input token limit + MaxInputTokens int `json:"maxInputTokens,omitempty"` +} + +// KiroIDETokenFile is the default path to Kiro IDE's token file +const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json" + +// LoadKiroIDEToken loads token data from Kiro IDE's token file. 
+func LoadKiroIDEToken() (*KiroTokenData, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, KiroIDETokenFile) + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in Kiro IDE token file") + } + + return &token, nil +} + +// LoadKiroTokenFromPath loads token data from a custom path. +// This supports multiple accounts by allowing different token files. +func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { + // Expand ~ to home directory + if len(tokenPath) > 0 && tokenPath[0] == '~' { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenPath = filepath.Join(homeDir, tokenPath[1:]) + } + + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in token file") + } + + return &token, nil +} + +// ListKiroTokenFiles lists all Kiro token files in the cache directory. +// This supports multiple accounts by finding all token files. 
+func ListKiroTokenFiles() ([]string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + return nil, nil // No token files + } + + entries, err := os.ReadDir(cacheDir) + if err != nil { + return nil, fmt.Errorf("failed to read cache directory: %w", err) + } + + var tokenFiles []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + // Look for kiro token files only (avoid matching unrelated AWS SSO cache files) + if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") { + tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name)) + } + } + + return tokenFiles, nil +} + +// LoadAllKiroTokens loads all Kiro tokens from the cache directory. +// This supports multiple accounts. +func LoadAllKiroTokens() ([]*KiroTokenData, error) { + files, err := ListKiroTokenFiles() + if err != nil { + return nil, err + } + + var tokens []*KiroTokenData + for _, file := range files { + token, err := LoadKiroTokenFromPath(file) + if err != nil { + // Skip invalid token files + continue + } + tokens = append(tokens, token) + } + + return tokens, nil +} + +// JWTClaims represents the claims we care about from a JWT token. +// JWT tokens from Kiro/AWS contain user information in the payload. +type JWTClaims struct { + Email string `json:"email,omitempty"` + Sub string `json:"sub,omitempty"` + PreferredUser string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Iss string `json:"iss,omitempty"` +} + +// ExtractEmailFromJWT extracts the user's email from a JWT access token. +// JWT tokens typically have format: header.payload.signature +// The payload is base64url-encoded JSON containing user claims. 
+func ExtractEmailFromJWT(accessToken string) string { + if accessToken == "" { + return "" + } + + // JWT format: header.payload.signature + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return "" + } + + // Decode the payload (second part) + payload := parts[1] + + // Add padding if needed (base64url requires padding) + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + // Try RawURLEncoding (no padding) + decoded, err = base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "" + } + } + + var claims JWTClaims + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + // Return email if available + if claims.Email != "" { + return claims.Email + } + + // Fallback to preferred_username (some providers use this) + if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") { + return claims.PreferredUser + } + + // Fallback to sub if it looks like an email + if claims.Sub != "" && strings.Contains(claims.Sub, "@") { + return claims.Sub + } + + return "" +} + +// SanitizeEmailForFilename sanitizes an email address for use in a filename. +// Replaces special characters with underscores and prevents path traversal attacks. +// Also handles URL-encoded characters to prevent encoded path traversal attempts. +func SanitizeEmailForFilename(email string) string { + if email == "" { + return "" + } + + result := email + + // First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.) + // This prevents encoded characters from bypassing the sanitization. 
+ // Note: We replace % last to catch any remaining encodings including double-encoding (%252F) + result = strings.ReplaceAll(result, "%2F", "_") // / + result = strings.ReplaceAll(result, "%2f", "_") + result = strings.ReplaceAll(result, "%5C", "_") // \ + result = strings.ReplaceAll(result, "%5c", "_") + result = strings.ReplaceAll(result, "%2E", "_") // . + result = strings.ReplaceAll(result, "%2e", "_") + result = strings.ReplaceAll(result, "%00", "_") // null byte + result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks + + // Replace characters that are problematic in filenames + // Keep @ and . in middle but replace other special characters + for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} { + result = strings.ReplaceAll(result, char, "_") + } + + // Prevent path traversal: replace leading dots in each path component + // This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd" + parts := strings.Split(result, "_") + for i, part := range parts { + for strings.HasPrefix(part, ".") { + part = "_" + part[1:] + } + parts[i] = part + } + result = strings.Join(parts, "_") + + return result +} diff --git a/internal/auth/kiro/aws_auth.go b/internal/auth/kiro/aws_auth.go new file mode 100644 index 00000000..53c77a8b --- /dev/null +++ b/internal/auth/kiro/aws_auth.go @@ -0,0 +1,314 @@ +// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API. +// This package implements token loading, refresh, and API communication with CodeWhisperer. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.) 
+ // Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com) + // used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct + // for their respective API operations. + awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com" + defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json" + targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits" + targetListModels = "AmazonCodeWhispererService.ListAvailableModels" + targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" +) + +// KiroAuth handles AWS CodeWhisperer authentication and API communication. +// It provides methods for loading tokens, refreshing expired tokens, +// and communicating with the CodeWhisperer API. +type KiroAuth struct { + httpClient *http.Client + endpoint string +} + +// NewKiroAuth creates a new Kiro authentication service. +// It initializes the HTTP client with proxy settings from the configuration. +// +// Parameters: +// - cfg: The application configuration containing proxy settings +// +// Returns: +// - *KiroAuth: A new Kiro authentication service instance +func NewKiroAuth(cfg *config.Config) *KiroAuth { + return &KiroAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}), + endpoint: awsKiroEndpoint, + } +} + +// LoadTokenFromFile loads token data from a file path. +// This method reads and parses the token file, expanding ~ to the home directory. 
+// +// Parameters: +// - tokenFile: Path to the token file (supports ~ expansion) +// +// Returns: +// - *KiroTokenData: The parsed token data +// - error: An error if file reading or parsing fails +func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) { + // Expand ~ to home directory + if strings.HasPrefix(tokenFile, "~") { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenFile = filepath.Join(home, tokenFile[1:]) + } + + data, err := os.ReadFile(tokenFile) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var tokenData KiroTokenData + if err := json.Unmarshal(data, &tokenData); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &tokenData, nil +} + +// IsTokenExpired checks if the token has expired. +// This method parses the expiration timestamp and compares it with the current time. +// +// Parameters: +// - tokenData: The token data to check +// +// Returns: +// - bool: True if the token has expired, false otherwise +func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool { + if tokenData.ExpiresAt == "" { + return true + } + + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + // Try alternate format + expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt) + if err != nil { + return true + } + } + + return time.Now().After(expiresAt) +} + +// makeRequest sends a request to the CodeWhisperer API. +// This is an internal method for making authenticated API calls. 
+// +// Parameters: +// - ctx: The context for the request +// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits") +// - accessToken: The OAuth access token +// - payload: The request payload +// +// Returns: +// - []byte: The response body +// - error: An error if the request fails +func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) { + jsonBody, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", target) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := k.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("failed to close response body: %v", errClose) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + return body, nil +} + +// GetUsageLimits retrieves usage information from the CodeWhisperer API. +// This method fetches the current usage statistics and subscription information. 
+// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data containing access token and profile ARN +// +// Returns: +// - *KiroUsageInfo: The usage information +// - error: An error if the request fails +func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + "resourceType": "AGENTIC_REQUEST", + } + + body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload) + if err != nil { + return nil, err + } + + var result struct { + SubscriptionInfo struct { + SubscriptionTitle string `json:"subscriptionTitle"` + } `json:"subscriptionInfo"` + UsageBreakdownList []struct { + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + } `json:"usageBreakdownList"` + NextDateReset float64 `json:"nextDateReset"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse usage response: %w", err) + } + + usage := &KiroUsageInfo{ + SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle, + NextReset: fmt.Sprintf("%v", result.NextDateReset), + } + + if len(result.UsageBreakdownList) > 0 { + usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision + usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision + } + + return usage, nil +} + +// ListAvailableModels retrieves available models from the CodeWhisperer API. +// This method fetches the list of AI models available for the authenticated user. 
+// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data containing access token and profile ARN +// +// Returns: +// - []*KiroModel: The list of available models +// - error: An error if the request fails +func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + } + + body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload) + if err != nil { + return nil, err + } + + var result struct { + Models []struct { + ModelID string `json:"modelId"` + ModelName string `json:"modelName"` + Description string `json:"description"` + RateMultiplier float64 `json:"rateMultiplier"` + RateUnit string `json:"rateUnit"` + TokenLimits struct { + MaxInputTokens int `json:"maxInputTokens"` + } `json:"tokenLimits"` + } `json:"models"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + + models := make([]*KiroModel, 0, len(result.Models)) + for _, m := range result.Models { + models = append(models, &KiroModel{ + ModelID: m.ModelID, + ModelName: m.ModelName, + Description: m.Description, + RateMultiplier: m.RateMultiplier, + RateUnit: m.RateUnit, + MaxInputTokens: m.TokenLimits.MaxInputTokens, + }) + } + + return models, nil +} + +// CreateTokenStorage creates a new KiroTokenStorage from token data. +// This method converts the token data into a storage structure suitable for persistence. 
+// +// Parameters: +// - tokenData: The token data to convert +// +// Returns: +// - *KiroTokenStorage: A new token storage instance +func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage { + return &KiroTokenStorage{ + AccessToken: tokenData.AccessToken, + RefreshToken: tokenData.RefreshToken, + ProfileArn: tokenData.ProfileArn, + ExpiresAt: tokenData.ExpiresAt, + AuthMethod: tokenData.AuthMethod, + Provider: tokenData.Provider, + LastRefresh: time.Now().Format(time.RFC3339), + } +} + +// ValidateToken checks if the token is valid by making a test API call. +// This method verifies the token by attempting to fetch usage limits. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data to validate +// +// Returns: +// - error: An error if the token is invalid +func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error { + _, err := k.GetUsageLimits(ctx, tokenData) + return err +} + +// UpdateTokenStorage updates an existing token storage with new token data. +// This method refreshes the token storage with newly obtained access and refresh tokens. 
+// +// Parameters: +// - storage: The existing token storage to update +// - tokenData: The new token data to apply +func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) { + storage.AccessToken = tokenData.AccessToken + storage.RefreshToken = tokenData.RefreshToken + storage.ProfileArn = tokenData.ProfileArn + storage.ExpiresAt = tokenData.ExpiresAt + storage.AuthMethod = tokenData.AuthMethod + storage.Provider = tokenData.Provider + storage.LastRefresh = time.Now().Format(time.RFC3339) +} diff --git a/internal/auth/kiro/aws_test.go b/internal/auth/kiro/aws_test.go new file mode 100644 index 00000000..5f60294c --- /dev/null +++ b/internal/auth/kiro/aws_test.go @@ -0,0 +1,161 @@ +package kiro + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestExtractEmailFromJWT(t *testing.T) { + tests := []struct { + name string + token string + expected string + }{ + { + name: "Empty token", + token: "", + expected: "", + }, + { + name: "Invalid token format", + token: "not.a.valid.jwt", + expected: "", + }, + { + name: "Invalid token - not base64", + token: "xxx.yyy.zzz", + expected: "", + }, + { + name: "Valid JWT with email", + token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}), + expected: "test@example.com", + }, + { + name: "JWT without email but with preferred_username", + token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}), + expected: "user@domain.com", + }, + { + name: "JWT with email-like sub", + token: createTestJWT(map[string]any{"sub": "another@test.com"}), + expected: "another@test.com", + }, + { + name: "JWT without any email fields", + token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractEmailFromJWT(tt.token) + if result != tt.expected { + t.Errorf("ExtractEmailFromJWT() = %q, want %q", 
result, tt.expected) + } + }) + } +} + +func TestSanitizeEmailForFilename(t *testing.T) { + tests := []struct { + name string + email string + expected string + }{ + { + name: "Empty email", + email: "", + expected: "", + }, + { + name: "Simple email", + email: "user@example.com", + expected: "user@example.com", + }, + { + name: "Email with space", + email: "user name@example.com", + expected: "user_name@example.com", + }, + { + name: "Email with special chars", + email: "user:name@example.com", + expected: "user_name@example.com", + }, + { + name: "Email with multiple special chars", + email: "user/name:test@example.com", + expected: "user_name_test@example.com", + }, + { + name: "Path traversal attempt", + email: "../../../etc/passwd", + expected: "_.__.__._etc_passwd", + }, + { + name: "Path traversal with backslash", + email: `..\..\..\..\windows\system32`, + expected: "_.__.__.__._windows_system32", + }, + { + name: "Null byte injection attempt", + email: "user\x00@evil.com", + expected: "user_@evil.com", + }, + // URL-encoded path traversal tests + { + name: "URL-encoded slash", + email: "user%2Fpath@example.com", + expected: "user_path@example.com", + }, + { + name: "URL-encoded backslash", + email: "user%5Cpath@example.com", + expected: "user_path@example.com", + }, + { + name: "URL-encoded dot", + email: "%2E%2E%2Fetc%2Fpasswd", + expected: "___etc_passwd", + }, + { + name: "URL-encoded null", + email: "user%00@evil.com", + expected: "user_@evil.com", + }, + { + name: "Double URL-encoding attack", + email: "%252F%252E%252E", + expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe) + }, + { + name: "Mixed case URL-encoding", + email: "%2f%2F%5c%5C", + expected: "____", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeEmailForFilename(tt.email) + if result != tt.expected { + t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected) + } + }) + } +} + +// 
createTestJWT creates a test JWT token with the given claims +func createTestJWT(claims map[string]any) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + + payloadBytes, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(payloadBytes) + + signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature")) + + return header + "." + payload + "." + signature +} diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go new file mode 100644 index 00000000..e828da14 --- /dev/null +++ b/internal/auth/kiro/oauth.go @@ -0,0 +1,296 @@ +// Package kiro provides OAuth2 authentication for Kiro using native Google login. +package kiro + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // Kiro auth endpoint + kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" + + // Default callback port + defaultCallbackPort = 9876 + + // Auth timeout + authTimeout = 10 * time.Minute +) + +// KiroTokenResponse represents the response from Kiro token endpoint. +type KiroTokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ProfileArn string `json:"profileArn"` + ExpiresIn int `json:"expiresIn"` +} + +// KiroOAuth handles the OAuth flow for Kiro authentication. +type KiroOAuth struct { + httpClient *http.Client + cfg *config.Config +} + +// NewKiroOAuth creates a new Kiro OAuth handler. 
+func NewKiroOAuth(cfg *config.Config) *KiroOAuth { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &KiroOAuth{ + httpClient: client, + cfg: cfg, + } +} + +// generateCodeVerifier generates a random code verifier for PKCE. +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// generateCodeChallenge generates the code challenge from verifier. +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +// generateState generates a random state parameter. +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// AuthResult contains the authorization code and state from callback. +type AuthResult struct { + Code string + State string + Error string +} + +// startCallbackServer starts a local HTTP server to receive the OAuth callback. 
+func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) { + // Try to find an available port - use localhost like Kiro does + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort)) + if err != nil { + // Try with dynamic port (RFC 8252 allows dynamic ports for native apps) + log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort) + listener, err = net.Listen("tcp", "localhost:0") + if err != nil { + return "", nil, fmt.Errorf("failed to start callback server: %w", err) + } + } + + port := listener.Addr().(*net.TCPAddr).Port + // Use http scheme for local callback server + redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port) + resultChan := make(chan AuthResult, 1) + + server := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + if errParam != "" { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `

Login Failed

%s

You can close this window.

`, html.EscapeString(errParam)) + resultChan <- AuthResult{Error: errParam} + return + } + + if state != expectedState { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `

Login Failed

Invalid state parameter

You can close this window.

`) + resultChan <- AuthResult{Error: "state mismatch"} + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `

Login Successful!

You can close this window and return to the terminal.

`) + resultChan <- AuthResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(authTimeout): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow. +func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + ssoClient := NewSSOOIDCClient(o.cfg) + return ssoClient.LoginWithBuilderID(ctx) +} + +// exchangeCodeForToken exchanges the authorization code for tokens. +func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) { + payload := map[string]string{ + "code": code, + "code_verifier": codeVerifier, + "redirect_uri": redirectURI, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + tokenURL := kiroAuthEndpoint + "/oauth/token" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp 
KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// RefreshToken refreshes an expired access token. +func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "refreshToken": refreshToken, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + refreshURL := kiroAuthEndpoint + "/refreshToken" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } 
+ + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// LoginWithGoogle performs OAuth login with Google using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGoogle(ctx) +} + +// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGitHub(ctx) +} diff --git a/internal/auth/kiro/protocol_handler.go b/internal/auth/kiro/protocol_handler.go new file mode 100644 index 00000000..e07cc24d --- /dev/null +++ b/internal/auth/kiro/protocol_handler.go @@ -0,0 +1,725 @@ +// Package kiro provides custom protocol handler registration for Kiro OAuth. +// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub). 
+package kiro + +import ( + "context" + "fmt" + "html" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // KiroProtocol is the custom URI scheme used by Kiro + KiroProtocol = "kiro" + + // KiroAuthority is the URI authority for authentication callbacks + KiroAuthority = "kiro.kiroAgent" + + // KiroAuthPath is the path for successful authentication + KiroAuthPath = "/authenticate-success" + + // KiroRedirectURI is the full redirect URI for social auth + KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success" + + // DefaultHandlerPort is the default port for the local callback server + DefaultHandlerPort = 19876 + + // HandlerTimeout is how long to wait for the OAuth callback + HandlerTimeout = 10 * time.Minute +) + +// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks. +type ProtocolHandler struct { + port int + server *http.Server + listener net.Listener + resultChan chan *AuthCallback + stopChan chan struct{} + mu sync.Mutex + running bool +} + +// AuthCallback contains the OAuth callback parameters. +type AuthCallback struct { + Code string + State string + Error string +} + +// NewProtocolHandler creates a new protocol handler. +func NewProtocolHandler() *ProtocolHandler { + return &ProtocolHandler{ + port: DefaultHandlerPort, + resultChan: make(chan *AuthCallback, 1), + stopChan: make(chan struct{}), + } +} + +// Start starts the local callback server that receives redirects from the protocol handler. 
+func (h *ProtocolHandler) Start(ctx context.Context) (int, error) { + h.mu.Lock() + defer h.mu.Unlock() + + if h.running { + return h.port, nil + } + + // Drain any stale results from previous runs + select { + case <-h.resultChan: + default: + } + + // Reset stopChan for reuse - close old channel first to unblock any waiting goroutines + if h.stopChan != nil { + select { + case <-h.stopChan: + // Already closed + default: + close(h.stopChan) + } + } + h.stopChan = make(chan struct{}) + + // Try ports in known range (must match handler script port range) + var listener net.Listener + var err error + portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4} + + for _, port := range portRange { + listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err == nil { + break + } + log.Debugf("kiro protocol handler: port %d busy, trying next", port) + } + + if listener == nil { + return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4) + } + + h.listener = listener + h.port = listener.Addr().(*net.TCPAddr).Port + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", h.handleCallback) + + h.server = &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + + go func() { + if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("kiro protocol handler server error: %v", err) + } + }() + + h.running = true + log.Debugf("kiro protocol handler started on port %d", h.port) + + // Auto-shutdown after context done, timeout, or explicit stop + // Capture references to prevent race with new Start() calls + currentStopChan := h.stopChan + currentServer := h.server + currentListener := h.listener + go func() { + select { + case <-ctx.Done(): + case <-time.After(HandlerTimeout): + case <-currentStopChan: + return // Already stopped, exit goroutine + } + // Only 
stop if this is still the current server/listener instance + h.mu.Lock() + if h.server == currentServer && h.listener == currentListener { + h.mu.Unlock() + h.Stop() + } else { + h.mu.Unlock() + } + }() + + return h.port, nil +} + +// Stop stops the callback server. +func (h *ProtocolHandler) Stop() { + h.mu.Lock() + defer h.mu.Unlock() + + if !h.running { + return + } + + // Signal the auto-shutdown goroutine to exit. + // This select pattern is safe because stopChan is only modified while holding h.mu, + // and we hold the lock here. The select prevents panic from double-close. + select { + case <-h.stopChan: + // Already closed + default: + close(h.stopChan) + } + + if h.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = h.server.Shutdown(ctx) + } + + h.running = false + log.Debug("kiro protocol handler stopped") +} + +// WaitForCallback waits for the OAuth callback and returns the result. +func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(HandlerTimeout): + return nil, fmt.Errorf("timeout waiting for OAuth callback") + case result := <-h.resultChan: + return result, nil + } +} + +// GetPort returns the port the handler is listening on. +func (h *ProtocolHandler) GetPort() int { + return h.port +} + +// handleCallback processes the OAuth callback from the protocol handler script. 
+func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + result := &AuthCallback{ + Code: code, + State: state, + Error: errParam, + } + + // Send result + select { + case h.resultChan <- result: + default: + // Channel full, ignore duplicate callbacks + } + + // Send success response + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if errParam != "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, ` + +Login Failed + +

Login Failed

+

Error: %s

+

You can close this window.

+ +`, html.EscapeString(errParam)) + } else { + fmt.Fprint(w, ` + +Login Successful + +

Login Successful!

+

You can close this window and return to the terminal.

+ + +`) + } +} + +// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed. +func IsProtocolHandlerInstalled() bool { + switch runtime.GOOS { + case "linux": + return isLinuxHandlerInstalled() + case "windows": + return isWindowsHandlerInstalled() + case "darwin": + return isDarwinHandlerInstalled() + default: + return false + } +} + +// InstallProtocolHandler installs the kiro:// protocol handler for the current platform. +func InstallProtocolHandler(handlerPort int) error { + switch runtime.GOOS { + case "linux": + return installLinuxHandler(handlerPort) + case "windows": + return installWindowsHandler(handlerPort) + case "darwin": + return installDarwinHandler(handlerPort) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// UninstallProtocolHandler removes the kiro:// protocol handler. +func UninstallProtocolHandler() error { + switch runtime.GOOS { + case "linux": + return uninstallLinuxHandler() + case "windows": + return uninstallWindowsHandler() + case "darwin": + return uninstallDarwinHandler() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// --- Linux Implementation --- + +func getLinuxDesktopPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop") +} + +func getLinuxHandlerScriptPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler") +} + +func isLinuxHandlerInstalled() bool { + desktopPath := getLinuxDesktopPath() + _, err := os.Stat(desktopPath) + return err == nil +} + +func installLinuxHandler(handlerPort int) error { + // Create directories + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + binDir := filepath.Join(homeDir, ".local", "bin") + appDir := filepath.Join(homeDir, ".local", "share", "applications") + + if err := os.MkdirAll(binDir, 0755); err != nil { + return 
fmt.Errorf("failed to create bin directory: %w", err) + } + if err := os.MkdirAll(appDir, 0755); err != nil { + return fmt.Errorf("failed to create applications directory: %w", err) + } + + // Create handler script - tries multiple ports to handle dynamic port allocation + scriptPath := getLinuxHandlerScriptPath() + scriptContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler +# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE + +URL="$1" + +# Check curl availability +if ! command -v curl &> /dev/null; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try CLI proxy on multiple possible ports (default + dynamic range) +CLI_OK=0 +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break + fi +done + +# If CLI not available, forward to Kiro IDE +if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then + /usr/share/kiro/kiro --open-url "$URL" & +fi +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create .desktop file + desktopPath := getLinuxDesktopPath() + desktopContent := fmt.Sprintf(`[Desktop Entry] +Name=Kiro OAuth Handler +Comment=Handle kiro:// protocol for CLI Proxy API authentication +Exec=%s %%u +Type=Application +Terminal=false +NoDisplay=true +MimeType=x-scheme-handler/kiro; +Categories=Utility; +`, scriptPath) + + if err := 
os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil { + return fmt.Errorf("failed to write desktop file: %w", err) + } + + // Register handler with xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("xdg-mime registration failed (may need manual setup): %v", err) + } + + // Update desktop database + cmd = exec.Command("update-desktop-database", appDir) + _ = cmd.Run() // Ignore errors, not critical + + log.Info("Kiro protocol handler installed for Linux") + return nil +} + +func uninstallLinuxHandler() error { + desktopPath := getLinuxDesktopPath() + scriptPath := getLinuxHandlerScriptPath() + + if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove desktop file: %w", err) + } + if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove handler script: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- Windows Implementation --- + +func isWindowsHandlerInstalled() bool { + // Check registry key existence + cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve") + return cmd.Run() == nil +} + +func installWindowsHandler(handlerPort int) error { + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + // Create handler script (PowerShell) + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + if err := os.MkdirAll(scriptDir, 0755); err != nil { + return fmt.Errorf("failed to create script directory: %w", err) + } + + scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1") + scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows +param([string]$url) + +# Load required assembly for HttpUtility +Add-Type -AssemblyName System.Web + +# Parse URL parameters +$uri = [System.Uri]$url +$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query) +$code = 
$query["code"] +$state = $query["state"] +$errorParam = $query["error"] + +# Try multiple ports (default + dynamic range) +$ports = @(%d, %d, %d, %d, %d) +$success = $false + +foreach ($port in $ports) { + if ($success) { break } + $callbackUrl = "http://127.0.0.1:$port/oauth/callback" + try { + if ($errorParam) { + $fullUrl = $callbackUrl + "?error=" + $errorParam + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } elseif ($code -and $state) { + $fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } + } catch { + # Try next port + } +} +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create batch wrapper + batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat") + batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" \"%%1\"\n", scriptPath) + + if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil { + return fmt.Errorf("failed to write batch wrapper: %w", err) + } + + // Register in Windows registry + commands := [][]string{ + {"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) 
+ if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run registry command: %w", err) + } + } + + log.Info("Kiro protocol handler installed for Windows") + return nil +} + +func uninstallWindowsHandler() error { + // Remove registry keys + cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f") + if err := cmd.Run(); err != nil { + log.Warnf("failed to remove registry key: %v", err) + } + + // Remove scripts + homeDir, _ := os.UserHomeDir() + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1")) + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat")) + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- macOS Implementation --- + +func getDarwinAppPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app") +} + +func isDarwinHandlerInstalled() bool { + appPath := getDarwinAppPath() + _, err := os.Stat(appPath) + return err == nil +} + +func installDarwinHandler(handlerPort int) error { + // Create app bundle structure + appPath := getDarwinAppPath() + contentsPath := filepath.Join(appPath, "Contents") + macOSPath := filepath.Join(contentsPath, "MacOS") + + if err := os.MkdirAll(macOSPath, 0755); err != nil { + return fmt.Errorf("failed to create app bundle: %w", err) + } + + // Create Info.plist + plistPath := filepath.Join(contentsPath, "Info.plist") + plistContent := ` + + + + CFBundleIdentifier + com.cliproxyapi.kiro-oauth-handler + CFBundleName + KiroOAuthHandler + CFBundleExecutable + kiro-oauth-handler + CFBundleVersion + 1.0 + CFBundleURLTypes + + + CFBundleURLName + Kiro Protocol + CFBundleURLSchemes + + kiro + + + + LSBackgroundOnly + + +` + + if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil { + return fmt.Errorf("failed to write Info.plist: %w", err) + } + + // Create executable script - tries multiple ports to handle dynamic port allocation + 
execPath := filepath.Join(macOSPath, "kiro-oauth-handler") + execContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler for macOS + +URL="$1" + +# Check curl availability (should always exist on macOS) +if [ ! -x /usr/bin/curl ]; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try multiple ports (default + dynamic range) +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0 + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0 + fi +done +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil { + return fmt.Errorf("failed to write executable: %w", err) + } + + // Register the app with Launch Services + cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister", + "-f", appPath) + if err := cmd.Run(); err != nil { + log.Warnf("lsregister failed (handler may still work): %v", err) + } + + log.Info("Kiro protocol handler installed for macOS") + return nil +} + +func uninstallDarwinHandler() error { + appPath := getDarwinAppPath() + + // Unregister from Launch Services + cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister", + "-u", appPath) + _ = cmd.Run() + + // Remove app bundle + if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove app bundle: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + 
return nil +} + +// ParseKiroURI parses a kiro:// URI and extracts the callback parameters. +func ParseKiroURI(rawURI string) (*AuthCallback, error) { + u, err := url.Parse(rawURI) + if err != nil { + return nil, fmt.Errorf("invalid URI: %w", err) + } + + if u.Scheme != KiroProtocol { + return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme) + } + + if u.Host != KiroAuthority { + return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host) + } + + query := u.Query() + return &AuthCallback{ + Code: query.Get("code"), + State: query.Get("state"), + Error: query.Get("error"), + }, nil +} + +// GetHandlerInstructions returns platform-specific instructions for manual handler setup. +func GetHandlerInstructions() string { + switch runtime.GOOS { + case "linux": + return `To manually set up the Kiro protocol handler on Linux: + +1. Create ~/.local/share/applications/kiro-oauth-handler.desktop: + [Desktop Entry] + Name=Kiro OAuth Handler + Exec=~/.local/bin/kiro-oauth-handler %u + Type=Application + Terminal=false + MimeType=x-scheme-handler/kiro; + +2. Create ~/.local/bin/kiro-oauth-handler (make it executable): + #!/bin/bash + URL="$1" + # ... (see generated script for full content) + +3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro` + + case "windows": + return `To manually set up the Kiro protocol handler on Windows: + +1. Open Registry Editor (regedit.exe) +2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro +3. Set default value to: URL:Kiro Protocol +4. Create string value "URL Protocol" with empty data +5. Create subkey: shell\open\command +6. Set default value to: "C:\path\to\handler.bat" "%1"` + + case "darwin": + return `To manually set up the Kiro protocol handler on macOS: + +1. Create ~/Applications/KiroOAuthHandler.app bundle +2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme +3. Create executable in Contents/MacOS/ +4. 
Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app` + + default: + return "Protocol handler setup is not supported on this platform." + } +} + +// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed. +func SetupProtocolHandlerIfNeeded(handlerPort int) error { + if IsProtocolHandlerInstalled() { + log.Debug("Kiro protocol handler already installed") + return nil + } + + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Protocol Handler Setup Required ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.") + fmt.Println("This allows your browser to redirect back to the CLI after authentication.") + fmt.Println("\nInstalling protocol handler...") + + if err := InstallProtocolHandler(handlerPort); err != nil { + fmt.Printf("\n⚠ Automatic installation failed: %v\n", err) + fmt.Println("\nManual setup instructions:") + fmt.Println(strings.Repeat("-", 60)) + fmt.Println(GetHandlerInstructions()) + return err + } + + fmt.Println("\n✓ Protocol handler installed successfully!") + return nil +} diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go new file mode 100644 index 00000000..61c67886 --- /dev/null +++ b/internal/auth/kiro/social_auth.go @@ -0,0 +1,403 @@ +// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient. 
+package kiro + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" + "golang.org/x/term" +) + +const ( + // Kiro AuthService endpoint + kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" + + // OAuth timeout + socialAuthTimeout = 10 * time.Minute +) + +// SocialProvider represents the social login provider. +type SocialProvider string + +const ( + // ProviderGoogle is Google OAuth provider + ProviderGoogle SocialProvider = "Google" + // ProviderGitHub is GitHub OAuth provider + ProviderGitHub SocialProvider = "Github" + // Note: AWS Builder ID is NOT supported by Kiro's auth service. + // It only supports: Google, Github, Cognito + // AWS Builder ID must use device code flow via SSO OIDC. +) + +// CreateTokenRequest is sent to Kiro's /oauth/token endpoint. +type CreateTokenRequest struct { + Code string `json:"code"` + CodeVerifier string `json:"code_verifier"` + RedirectURI string `json:"redirect_uri"` + InvitationCode string `json:"invitation_code,omitempty"` +} + +// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth. +type SocialTokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ProfileArn string `json:"profileArn"` + ExpiresIn int `json:"expiresIn"` +} + +// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint. +type RefreshTokenRequest struct { + RefreshToken string `json:"refreshToken"` +} + +// SocialAuthClient handles social authentication with Kiro. 
+type SocialAuthClient struct { + httpClient *http.Client + cfg *config.Config + protocolHandler *ProtocolHandler +} + +// NewSocialAuthClient creates a new social auth client. +func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SocialAuthClient{ + httpClient: client, + cfg: cfg, + protocolHandler: NewProtocolHandler(), + } +} + +// generatePKCE generates PKCE code verifier and challenge. +func generatePKCE() (verifier, challenge string, err error) { + // Generate 32 bytes of random data for verifier + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + // Generate SHA256 hash of verifier for challenge + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} + +// generateState generates a random state parameter. +func generateStateParam() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// buildLoginURL constructs the Kiro OAuth login URL. +// The login endpoint expects a GET request with query parameters. +// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account +// The prompt=select_account parameter forces the account selection screen even if already logged in. 
+func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string { + return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + kiroAuthServiceEndpoint, + provider, + url.QueryEscape(redirectURI), + codeChallenge, + state, + ) +} + +// createToken exchanges the authorization code for tokens. +func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal token request: %w", err) + } + + tokenURL := kiroAuthServiceEndpoint + "/oauth/token" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &tokenResp, nil +} + +// RefreshSocialToken refreshes an expired social auth token. 
+func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken}) + if err != nil { + return nil, fmt.Errorf("failed to marshal refresh request: %w", err) + } + + refreshURL := kiroAuthServiceEndpoint + "/refreshToken" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse refresh response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 // Default 1 hour + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// LoginWithSocial performs OAuth login with Google. 
// LoginWithSocial performs the browser-based OAuth 2.0 + PKCE flow for a
// social identity provider and returns the resulting Kiro token data.
//
// Flow: start the local kiro:// callback server, ensure the protocol handler
// is registered, generate PKCE and state values, open the provider login URL
// in a browser (incognito by default for multi-account support), wait for the
// redirect callback, validate state, then exchange the code for tokens.
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
	providerName := string(provider)

	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
	fmt.Println("╚══════════════════════════════════════════════════════════╝")

	// Step 1: Setup protocol handler
	fmt.Println("\nSetting up authentication...")

	// Start the local callback server
	handlerPort, err := c.protocolHandler.Start(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to start callback server: %w", err)
	}
	defer c.protocolHandler.Stop()

	// Ensure protocol handler is installed and set as default
	if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
		fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
		fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
		fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
		log.Debugf("kiro: protocol handler setup error: %v", err)
		// Continue anyway - user might have set it up manually or select browser manually
	} else {
		// Force set our handler as default (prevents "Open with" dialog)
		forceDefaultProtocolHandler()
	}

	// Step 2: Generate PKCE codes
	codeVerifier, codeChallenge, err := generatePKCE()
	if err != nil {
		return nil, fmt.Errorf("failed to generate PKCE: %w", err)
	}

	// Step 3: Generate state (CSRF protection, checked against the callback)
	state, err := generateStateParam()
	if err != nil {
		return nil, fmt.Errorf("failed to generate state: %w", err)
	}

	// Step 4: Build the login URL (Kiro uses GET request with query params)
	authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)

	// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
	// Incognito mode enables multi-account support by bypassing cached sessions
	if c.cfg != nil {
		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
		if !c.cfg.IncognitoBrowser {
			log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
		} else {
			log.Debug("kiro: using incognito mode for multi-account support")
		}
	} else {
		browser.SetIncognitoMode(true) // Default to incognito if no config
		log.Debug("kiro: using incognito mode for multi-account support (default)")
	}

	// Step 5: Open browser for user authentication
	fmt.Println("\n════════════════════════════════════════════════════════════")
	fmt.Printf(" Opening browser for %s authentication...\n", providerName)
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("\n URL: %s\n\n", authURL)

	if err := browser.OpenURL(authURL); err != nil {
		log.Warnf("Could not open browser automatically: %v", err)
		fmt.Println(" ⚠ Could not open browser automatically.")
		fmt.Println(" Please open the URL above in your browser manually.")
	} else {
		fmt.Println(" (Browser opened automatically)")
	}

	fmt.Println("\n Waiting for authentication callback...")

	// Step 6: Wait for callback (blocks until redirect arrives or ctx is done)
	callback, err := c.protocolHandler.WaitForCallback(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to receive callback: %w", err)
	}

	if callback.Error != "" {
		return nil, fmt.Errorf("authentication error: %s", callback.Error)
	}

	if callback.State != state {
		// Log state values for debugging, but don't expose in user-facing error
		log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
		return nil, fmt.Errorf("OAuth state validation failed - please try again")
	}

	if callback.Code == "" {
		return nil, fmt.Errorf("no authorization code received")
	}

	fmt.Println("\n✓ Authorization received!")

	// Step 7: Exchange code for tokens
	fmt.Println("Exchanging code for tokens...")

	tokenReq := &CreateTokenRequest{
		Code:         callback.Code,
		CodeVerifier: codeVerifier,
		RedirectURI:  KiroRedirectURI,
	}

	tokenResp, err := c.createToken(ctx, tokenReq)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
	}

	fmt.Println("\n✓ Authentication successful!")

	// Close the browser window (best effort; failure is only logged)
	if err := browser.CloseBrowser(); err != nil {
		log.Debugf("Failed to close browser: %v", err)
	}

	// Validate ExpiresIn - use default 1 hour if invalid
	expiresIn := tokenResp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)

	// Try to extract email from JWT access token first
	email := ExtractEmailFromJWT(tokenResp.AccessToken)

	// If no email in JWT, ask user for account label (only in interactive mode)
	if email == "" && isInteractiveTerminal() {
		fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
		reader := bufio.NewReader(os.Stdin)
		var err error
		email, err = reader.ReadString('\n')
		if err != nil {
			log.Debugf("Failed to read account label: %v", err)
		}
		email = strings.TrimSpace(email)
	}

	return &KiroTokenData{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		ProfileArn:   tokenResp.ProfileArn,
		ExpiresAt:    expiresAt.Format(time.RFC3339),
		AuthMethod:   "social",
		Provider:     providerName,
		Email:        email, // JWT email or user-provided label
	}, nil
}

// LoginWithGoogle performs OAuth login with Google.
func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
	return c.LoginWithSocial(ctx, ProviderGoogle)
}

// LoginWithGitHub performs OAuth login with GitHub.
func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
	return c.LoginWithSocial(ctx, ProviderGitHub)
}

// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
// This prevents the "Open with" dialog from appearing on Linux.
+// On non-Linux platforms, this is a no-op as they use different mechanisms. +func forceDefaultProtocolHandler() { + if runtime.GOOS != "linux" { + return // Non-Linux platforms use different handler mechanisms + } + + // Set our handler as default using xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err) + } +} + +// isInteractiveTerminal checks if stdin is connected to an interactive terminal. +// Returns false in CI/automated environments or when stdin is piped. +func isInteractiveTerminal() bool { + return term.IsTerminal(int(os.Stdin.Fd())) +} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go new file mode 100644 index 00000000..d3c27d16 --- /dev/null +++ b/internal/auth/kiro/sso_oidc.go @@ -0,0 +1,527 @@ +// Package kiro provides AWS SSO OIDC authentication for Kiro. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // AWS SSO OIDC endpoints + ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" + + // Kiro's start URL for Builder ID + builderIDStartURL = "https://view.awsapps.com/start" + + // Polling interval + pollInterval = 5 * time.Second +) + +// SSOOIDCClient handles AWS SSO OIDC authentication. +type SSOOIDCClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewSSOOIDCClient creates a new SSO OIDC client. 
+func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SSOOIDCClient{ + httpClient: client, + cfg: cfg, + } +} + +// RegisterClientResponse from AWS SSO OIDC. +type RegisterClientResponse struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` + ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` +} + +// StartDeviceAuthResponse from AWS SSO OIDC. +type StartDeviceAuthResponse struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationURI string `json:"verificationUri"` + VerificationURIComplete string `json:"verificationUriComplete"` + ExpiresIn int `json:"expiresIn"` + Interval int `json:"interval"` +} + +// CreateTokenResponse from AWS SSO OIDC. +type CreateTokenResponse struct { + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + ExpiresIn int `json:"expiresIn"` + RefreshToken string `json:"refreshToken"` +} + +// RegisterClient registers a new OIDC client with AWS. 
+func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { + // Generate unique client name for each registration to support multiple accounts + clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano()) + + payload := map[string]interface{}{ + "clientName": clientName, + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorization starts the device authorization flow. 
+func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateToken polls for the access token after user authorization. 
+func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, fmt.Errorf("authorization_pending") + } + if errResp.Error == "slow_down" { + return nil, fmt.Errorf("slow_down") + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshToken refreshes an access token using the refresh token. 
+func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + }, nil +} + +// LoginWithBuilderID performs the full device code flow for AWS Builder ID. 
func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Println("║ Kiro Authentication (AWS Builder ID) ║")
	fmt.Println("╚══════════════════════════════════════════════════════════╝")

	// Step 1: Register client
	fmt.Println("\nRegistering client...")
	regResp, err := c.RegisterClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to register client: %w", err)
	}
	log.Debugf("Client registered: %s", regResp.ClientID)

	// Step 2: Start device authorization
	fmt.Println("Starting device authorization...")
	authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
	if err != nil {
		return nil, fmt.Errorf("failed to start device auth: %w", err)
	}

	// Step 3: Show user the verification URL
	fmt.Printf("\n")
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf(" Open this URL in your browser:\n")
	fmt.Printf(" %s\n", authResp.VerificationURIComplete)
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI)
	fmt.Printf(" And enter code: %s\n\n", authResp.UserCode)

	// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
	// Incognito mode enables multi-account support by bypassing cached sessions
	if c.cfg != nil {
		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
		if !c.cfg.IncognitoBrowser {
			log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
		} else {
			log.Debug("kiro: using incognito mode for multi-account support")
		}
	} else {
		browser.SetIncognitoMode(true) // Default to incognito if no config
		log.Debug("kiro: using incognito mode for multi-account support (default)")
	}

	// Open browser using cross-platform browser package
	if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
		log.Warnf("Could not open browser automatically: %v", err)
		fmt.Println(" Please open the URL manually in your browser.")
	} else {
		fmt.Println(" (Browser opened automatically)")
	}

	// Step 4: Poll for token
	fmt.Println("Waiting for authorization...")

	interval := pollInterval
	if authResp.Interval > 0 {
		// Honor the server-suggested polling interval when provided.
		interval = time.Duration(authResp.Interval) * time.Second
	}

	// The device code itself expires after ExpiresIn seconds; stop polling then.
	deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)

	for time.Now().Before(deadline) {
		select {
		case <-ctx.Done():
			browser.CloseBrowser() // Cleanup on cancel
			return nil, ctx.Err()
		case <-time.After(interval):
			tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
			if err != nil {
				// Retryable device-flow conditions are signaled via the error
				// message (see CreateToken).
				errStr := err.Error()
				if strings.Contains(errStr, "authorization_pending") {
					fmt.Print(".")
					continue
				}
				if strings.Contains(errStr, "slow_down") {
					// Back off by 5s as the server requested.
					interval += 5 * time.Second
					continue
				}
				// Close browser on error before returning
				browser.CloseBrowser()
				return nil, fmt.Errorf("token creation failed: %w", err)
			}

			fmt.Println("\n\n✓ Authorization successful!")

			// Close the browser window
			if err := browser.CloseBrowser(); err != nil {
				log.Debugf("Failed to close browser: %v", err)
			}

			// Step 5: Get profile ARN from CodeWhisperer API (best effort;
			// may return empty string)
			fmt.Println("Fetching profile information...")
			profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)

			// Extract email from JWT access token
			email := ExtractEmailFromJWT(tokenResp.AccessToken)
			if email != "" {
				fmt.Printf(" Logged in as: %s\n", email)
			}

			expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)

			return &KiroTokenData{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				ProfileArn:   profileArn,
				ExpiresAt:    expiresAt.Format(time.RFC3339),
				AuthMethod:   "builder-id",
				Provider:     "AWS",
				ClientID:     regResp.ClientID,
				ClientSecret: regResp.ClientSecret,
				Email:        email,
			}, nil
		}
	}

	// Close browser on timeout for better UX
	if err := browser.CloseBrowser(); err != nil {
		log.Debugf("Failed to close browser on timeout: %v", err)
	}
	return nil, fmt.Errorf("authorization timed out")
}

// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
// Best effort: returns "" when both probes fail.
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
	// Try ListProfiles API first
	profileArn := c.tryListProfiles(ctx, accessToken)
	if profileArn != "" {
		return profileArn
	}

	// Fallback: Try ListAvailableCustomizations
	return c.tryListCustomizations(ctx, accessToken)
}

// tryListProfiles queries the CodeWhisperer ListProfiles API and returns the
// first profile ARN found, or "" on any failure (all errors are swallowed
// because this lookup is only used for file naming).
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
	payload := map[string]interface{}{
		"origin": "AI_EDITOR",
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return ""
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
	if err != nil {
		return ""
	}

	req.Header.Set("Content-Type", "application/x-amz-json-1.0")
	req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles")
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return ""
	}
	defer resp.Body.Close()

	respBody, _ := io.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
		return ""
	}

	log.Debugf("ListProfiles response: %s", string(respBody))

	// Accept either a top-level profileArn or the first entry of a profiles list.
	var result struct {
		Profiles []struct {
			Arn string `json:"arn"`
		} `json:"profiles"`
		ProfileArn string `json:"profileArn"`
	}

	if err := json.Unmarshal(respBody, &result); err != nil {
		return ""
	}

	if result.ProfileArn != "" {
		return result.ProfileArn
	}

	if len(result.Profiles) > 0 {
		return result.Profiles[0].Arn
	}

	return ""
}

// tryListCustomizations queries ListAvailableCustomizations as a fallback
// source of a profile ARN; same best-effort semantics as tryListProfiles.
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
	payload := map[string]interface{}{
		"origin": "AI_EDITOR",
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return ""
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
	if err != nil {
		return ""
	}

	req.Header.Set("Content-Type", "application/x-amz-json-1.0")
	req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return ""
	}
	defer resp.Body.Close()

	respBody, _ := io.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
		return ""
	}

	log.Debugf("ListAvailableCustomizations response: %s", string(respBody))

	var result struct {
		Customizations []struct {
			Arn string `json:"arn"`
		} `json:"customizations"`
		ProfileArn string `json:"profileArn"`
	}

	if err := json.Unmarshal(respBody, &result); err != nil {
		return ""
	}

	if result.ProfileArn != "" {
		return result.ProfileArn
	}

	if len(result.Customizations) > 0 {
		return result.Customizations[0].Arn
	}

	return ""
}
diff --git
a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go new file mode 100644 index 00000000..e83b1728 --- /dev/null +++ b/internal/auth/kiro/token.go @@ -0,0 +1,72 @@ +package kiro + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// KiroTokenStorage holds the persistent token data for Kiro authentication. +type KiroTokenStorage struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"access_token"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refresh_token"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profile_arn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expires_at"` + // AuthMethod indicates the authentication method used + AuthMethod string `json:"auth_method"` + // Provider indicates the OAuth provider + Provider string `json:"provider"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// SaveTokenToFile persists the token storage to the specified file path. +func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error { + dir := filepath.Dir(authFilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal token storage: %w", err) + } + + if err := os.WriteFile(authFilePath, data, 0600); err != nil { + return fmt.Errorf("failed to write token file: %w", err) + } + + return nil +} + +// LoadFromFile loads token storage from the specified file path. 
+func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) { + data, err := os.ReadFile(authFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var storage KiroTokenStorage + if err := json.Unmarshal(data, &storage); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &storage, nil +} + +// ToTokenData converts storage to KiroTokenData for API use. +func (s *KiroTokenStorage) ToTokenData() *KiroTokenData { + return &KiroTokenData{ + AccessToken: s.AccessToken, + RefreshToken: s.RefreshToken, + ProfileArn: s.ProfileArn, + ExpiresAt: s.ExpiresAt, + AuthMethod: s.AuthMethod, + Provider: s.Provider, + } +} diff --git a/internal/browser/browser.go b/internal/browser/browser.go index b24dc5e1..1ff0b469 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -6,14 +6,48 @@ import ( "fmt" "os/exec" "runtime" + "sync" + pkgbrowser "github.com/pkg/browser" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" ) +// incognitoMode controls whether to open URLs in incognito/private mode. +// This is useful for OAuth flows where you want to use a different account. +var incognitoMode bool + +// lastBrowserProcess stores the last opened browser process for cleanup +var lastBrowserProcess *exec.Cmd +var browserMutex sync.Mutex + +// SetIncognitoMode enables or disables incognito/private browsing mode. +func SetIncognitoMode(enabled bool) { + incognitoMode = enabled +} + +// IsIncognitoMode returns whether incognito mode is enabled. +func IsIncognitoMode() bool { + return incognitoMode +} + +// CloseBrowser closes the last opened browser process. 
+func CloseBrowser() error { + browserMutex.Lock() + defer browserMutex.Unlock() + + if lastBrowserProcess == nil || lastBrowserProcess.Process == nil { + return nil + } + + err := lastBrowserProcess.Process.Kill() + lastBrowserProcess = nil + return err +} + // OpenURL opens the specified URL in the default web browser. -// It first attempts to use a platform-agnostic library and falls back to -// platform-specific commands if that fails. +// It uses the pkg/browser library which provides robust cross-platform support +// for Windows, macOS, and Linux. +// If incognito mode is enabled, it will open in a private/incognito window. // // Parameters: // - url: The URL to open. @@ -21,16 +55,22 @@ import ( // Returns: // - An error if the URL cannot be opened, otherwise nil. func OpenURL(url string) error { - fmt.Printf("Attempting to open URL in browser: %s\n", url) + log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode) - // Try using the open-golang library first - err := open.Run(url) + // If incognito mode is enabled, use platform-specific incognito commands + if incognitoMode { + log.Debug("Using incognito mode") + return openURLIncognito(url) + } + + // Use pkg/browser for cross-platform support + err := pkgbrowser.OpenURL(url) if err == nil { - log.Debug("Successfully opened URL using open-golang library") + log.Debug("Successfully opened URL using pkg/browser library") return nil } - log.Debugf("open-golang failed: %v, trying platform-specific commands", err) + log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err) // Fallback to platform-specific commands return openURLPlatformSpecific(url) @@ -78,18 +118,394 @@ func openURLPlatformSpecific(url string) error { return nil } +// openURLIncognito opens a URL in incognito/private browsing mode. +// It first tries to detect the default browser and use its incognito flag. +// Falls back to a chain of known browsers if detection fails. 
+// +// Parameters: +// - url: The URL to open. +// +// Returns: +// - An error if the URL cannot be opened, otherwise nil. +func openURLIncognito(url string) error { + // First, try to detect and use the default browser + if cmd := tryDefaultBrowserIncognito(url); cmd != nil { + log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:]) + if err := cmd.Start(); err == nil { + storeBrowserProcess(cmd) + log.Debug("Successfully opened URL in default browser's incognito mode") + return nil + } + log.Debugf("Failed to start default browser, trying fallback chain") + } + + // Fallback to known browser chain + cmd := tryFallbackBrowsersIncognito(url) + if cmd == nil { + log.Warn("No browser with incognito support found, falling back to normal mode") + return openURLPlatformSpecific(url) + } + + log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:]) + err := cmd.Start() + if err != nil { + log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err) + return openURLPlatformSpecific(url) + } + + storeBrowserProcess(cmd) + log.Debug("Successfully opened URL in incognito/private mode") + return nil +} + +// storeBrowserProcess safely stores the browser process for later cleanup. +func storeBrowserProcess(cmd *exec.Cmd) { + browserMutex.Lock() + lastBrowserProcess = cmd + browserMutex.Unlock() +} + +// tryDefaultBrowserIncognito attempts to detect the default browser and return +// an exec.Cmd configured with the appropriate incognito flag. +func tryDefaultBrowserIncognito(url string) *exec.Cmd { + switch runtime.GOOS { + case "darwin": + return tryDefaultBrowserMacOS(url) + case "windows": + return tryDefaultBrowserWindows(url) + case "linux": + return tryDefaultBrowserLinux(url) + } + return nil +} + +// tryDefaultBrowserMacOS detects the default browser on macOS. 
+func tryDefaultBrowserMacOS(url string) *exec.Cmd { + // Try to get default browser from Launch Services + out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output() + if err != nil { + return nil + } + + output := string(out) + var browserName string + + // Parse the output to find the http/https handler + if containsBrowserID(output, "com.google.chrome") { + browserName = "chrome" + } else if containsBrowserID(output, "org.mozilla.firefox") { + browserName = "firefox" + } else if containsBrowserID(output, "com.apple.safari") { + browserName = "safari" + } else if containsBrowserID(output, "com.brave.browser") { + browserName = "brave" + } else if containsBrowserID(output, "com.microsoft.edgemac") { + browserName = "edge" + } + + return createMacOSIncognitoCmd(browserName, url) +} + +// containsBrowserID checks if the LaunchServices output contains a browser ID. +func containsBrowserID(output, bundleID string) bool { + return stringContains(output, bundleID) +} + +// stringContains is a simple string contains check. +func stringContains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers. 
+func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + // Try direct path first + chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" + if _, err := exec.LookPath(chromePath); err == nil { + return exec.Command(chromePath, "--incognito", url) + } + return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url) + case "firefox": + return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url) + case "safari": + // Safari doesn't have CLI incognito, try AppleScript + return tryAppleScriptSafariPrivate(url) + case "brave": + return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url) + case "edge": + return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url) + } + return nil +} + +// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript. +func tryAppleScriptSafariPrivate(url string) *exec.Cmd { + // AppleScript to open a new private window in Safari + script := fmt.Sprintf(` + tell application "Safari" + activate + tell application "System Events" + keystroke "n" using {command down, shift down} + delay 0.5 + end tell + set URL of document 1 to "%s" + end tell + `, url) + + cmd := exec.Command("osascript", "-e", script) + // Test if this approach works by checking if Safari is available + if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil { + log.Debug("Safari not found, AppleScript private window not available") + return nil + } + log.Debug("Attempting Safari private window via AppleScript") + return cmd +} + +// tryDefaultBrowserWindows detects the default browser on Windows via registry. 
func tryDefaultBrowserWindows(url string) *exec.Cmd {
	// Query registry for default browser (per-user http URL association)
	out, err := exec.Command("reg", "query",
		`HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`,
		"/v", "ProgId").Output()
	if err != nil {
		return nil
	}

	output := string(out)
	var browserName string

	// Map ProgId to browser name
	if stringContains(output, "ChromeHTML") {
		browserName = "chrome"
	} else if stringContains(output, "FirefoxURL") {
		browserName = "firefox"
	} else if stringContains(output, "MSEdgeHTM") {
		browserName = "edge"
	} else if stringContains(output, "BraveHTML") {
		browserName = "brave"
	}

	return createWindowsIncognitoCmd(browserName, url)
}

// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers.
// Candidate paths are probed in order (PATH name first, then well-known
// install locations); returns nil when no candidate resolves.
func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd {
	switch browserName {
	case "chrome":
		paths := []string{
			"chrome",
			`C:\Program Files\Google\Chrome\Application\chrome.exe`,
			`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
		}
		for _, p := range paths {
			if _, err := exec.LookPath(p); err == nil {
				return exec.Command(p, "--incognito", url)
			}
		}
	case "firefox":
		if path, err := exec.LookPath("firefox"); err == nil {
			return exec.Command(path, "--private-window", url)
		}
	case "edge":
		paths := []string{
			"msedge",
			`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
			`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
		}
		for _, p := range paths {
			if _, err := exec.LookPath(p); err == nil {
				return exec.Command(p, "--inprivate", url)
			}
		}
	case "brave":
		paths := []string{
			`C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`,
			`C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`,
		}
		for _, p := range paths {
			if _, err := exec.LookPath(p); err == nil {
				return exec.Command(p, "--incognito", url)
			}
		}
	}

	return nil
}

// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings.
func tryDefaultBrowserLinux(url string) *exec.Cmd {
	out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output()
	if err != nil {
		return nil
	}

	desktop := string(out)
	var browserName string

	// Map .desktop file to browser name. Order matters: "google-chrome" is
	// checked before the generic "chromium" match.
	if stringContains(desktop, "google-chrome") || stringContains(desktop, "chrome") {
		browserName = "chrome"
	} else if stringContains(desktop, "firefox") {
		browserName = "firefox"
	} else if stringContains(desktop, "chromium") {
		browserName = "chromium"
	} else if stringContains(desktop, "brave") {
		browserName = "brave"
	} else if stringContains(desktop, "microsoft-edge") || stringContains(desktop, "msedge") {
		browserName = "edge"
	}

	return createLinuxIncognitoCmd(browserName, url)
}

// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers.
// Returns nil when the selected browser binary cannot be found on PATH.
func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd {
	switch browserName {
	case "chrome":
		paths := []string{"google-chrome", "google-chrome-stable"}
		for _, p := range paths {
			if path, err := exec.LookPath(p); err == nil {
				return exec.Command(path, "--incognito", url)
			}
		}
	case "firefox":
		paths := []string{"firefox", "firefox-esr"}
		for _, p := range paths {
			if path, err := exec.LookPath(p); err == nil {
				return exec.Command(path, "--private-window", url)
			}
		}
	case "chromium":
		paths := []string{"chromium", "chromium-browser"}
		for _, p := range paths {
			if path, err := exec.LookPath(p); err == nil {
				return exec.Command(path, "--incognito", url)
			}
		}
	case "brave":
		if path, err := exec.LookPath("brave-browser"); err == nil {
			return exec.Command(path, "--incognito", url)
		}
	case "edge":
		if path, err := exec.LookPath("microsoft-edge"); err == nil {
			return exec.Command(path, "--inprivate", url)
		}
	}
	return nil
}

// tryFallbackBrowsersIncognito tries a
chain of known browsers as fallback. +func tryFallbackBrowsersIncognito(url string) *exec.Cmd { + switch runtime.GOOS { + case "darwin": + return tryFallbackBrowsersMacOS(url) + case "windows": + return tryFallbackBrowsersWindows(url) + case "linux": + return tryFallbackBrowsersLinuxChain(url) + } + return nil +} + +// tryFallbackBrowsersMacOS tries known browsers on macOS. +func tryFallbackBrowsersMacOS(url string) *exec.Cmd { + // Try Chrome + chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" + if _, err := exec.LookPath(chromePath); err == nil { + return exec.Command(chromePath, "--incognito", url) + } + // Try Firefox + if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil { + return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url) + } + // Try Brave + if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil { + return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url) + } + // Try Edge + if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil { + return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url) + } + // Last resort: try Safari with AppleScript + if cmd := tryAppleScriptSafariPrivate(url); cmd != nil { + log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)") + return cmd + } + return nil +} + +// tryFallbackBrowsersWindows tries known browsers on Windows. 
+func tryFallbackBrowsersWindows(url string) *exec.Cmd { + // Chrome + chromePaths := []string{ + "chrome", + `C:\Program Files\Google\Chrome\Application\chrome.exe`, + `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`, + } + for _, p := range chromePaths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + // Firefox + if path, err := exec.LookPath("firefox"); err == nil { + return exec.Command(path, "--private-window", url) + } + // Edge (usually available on Windows 10+) + edgePaths := []string{ + "msedge", + `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`, + `C:\Program Files\Microsoft\Edge\Application\msedge.exe`, + } + for _, p := range edgePaths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--inprivate", url) + } + } + return nil +} + +// tryFallbackBrowsersLinuxChain tries known browsers on Linux. +func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd { + type browserConfig struct { + name string + flag string + } + browsers := []browserConfig{ + {"google-chrome", "--incognito"}, + {"google-chrome-stable", "--incognito"}, + {"chromium", "--incognito"}, + {"chromium-browser", "--incognito"}, + {"firefox", "--private-window"}, + {"firefox-esr", "--private-window"}, + {"brave-browser", "--incognito"}, + {"microsoft-edge", "--inprivate"}, + } + for _, b := range browsers { + if path, err := exec.LookPath(b.name); err == nil { + return exec.Command(path, b.flag, url) + } + } + return nil +} + // IsAvailable checks if the system has a command available to open a web browser. // It verifies the presence of necessary commands for the current operating system. // // Returns: // - true if a browser can be opened, false otherwise. 
func IsAvailable() bool { - // First check if open-golang can work - testErr := open.Run("about:blank") - if testErr == nil { - return true - } - // Check platform-specific commands switch runtime.GOOS { case "darwin": diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index e6caa954..692ea34d 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), + sdkAuth.NewKiroAuthenticator(), ) return manager } diff --git a/internal/cmd/kiro_login.go b/internal/cmd/kiro_login.go new file mode 100644 index 00000000..5fc3b9eb --- /dev/null +++ b/internal/cmd/kiro_login.go @@ -0,0 +1,160 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoKiroLogin triggers the Kiro authentication flow with Google OAuth. +// This is the default login method (same as --kiro-google-login). +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including Prompt field +func DoKiroLogin(cfg *config.Config, options *LoginOptions) { + // Use Google login as default + DoKiroGoogleLogin(cfg, options) +} + +// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth. +// This uses a custom protocol handler (kiro://) to receive the callback. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. 
+ + manager := newAuthManager() + + // Use KiroAuthenticator with Google login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro Google authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure the protocol handler is installed") + fmt.Println("2. Complete the Google login in the browser") + fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro Google authentication successful!") +} + +// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID. +// This uses the device code flow for AWS SSO OIDC authentication. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. 
+ + manager := newAuthManager() + + // Use KiroAuthenticator with AWS Builder ID login (device code flow) + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro AWS authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure you have an AWS Builder ID") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro AWS authentication successful!") +} + +// DoKiroImport imports Kiro token from Kiro IDE's token file. +// This is useful for users who have already logged in via Kiro IDE +// and want to use the same credentials in CLI Proxy API. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options (currently unused for import) +func DoKiroImport(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + // Use ImportFromKiroIDE instead of Login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg) + if err != nil { + log.Errorf("Kiro token import failed: %v", err) + fmt.Println("\nMake sure you have logged in to Kiro IDE first:") + fmt.Println("1. Open Kiro IDE") + fmt.Println("2. Click 'Sign in with Google' (or GitHub)") + fmt.Println("3. Complete the login process") + fmt.Println("4. 
Run this command again") + return + } + + // Save the imported auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Imported as %s\n", record.Label) + } + fmt.Println("Kiro token import successful!") +} diff --git a/internal/config/config.go b/internal/config/config.go index 2681d049..1c72ece4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -58,6 +58,9 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations. + KiroKey []KiroKey `yaml:"kiro" json:"kiro"` + // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` @@ -80,6 +83,11 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode. + // This is useful when you want to login with a different account without logging out + // from your current session. Default: false. + IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"` + legacyMigrationPending bool `yaml:"-" json:"-"` } @@ -240,6 +248,31 @@ type GeminiKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication. 
+type KiroKey struct { + // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json) + TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"` + + // AccessToken is the OAuth access token for direct configuration. + AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"` + + // RefreshToken is the OAuth refresh token for token renewal. + RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"` + + // ProfileArn is the AWS CodeWhisperer profile ARN. + ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"` + + // Region is the AWS region (default: us-east-1). + Region string `yaml:"region,omitempty" json:"region,omitempty"` + + // ProxyURL optionally overrides the global proxy for this configuration. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat". + // Leave empty to let API use defaults. Different values may inject different system prompts. + AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"` +} + // OpenAICompatibility represents the configuration for OpenAI API compatibility // with external providers, allowing model aliases to be routed through OpenAI API format. type OpenAICompatibility struct { @@ -320,6 +353,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access + cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) if err = yaml.Unmarshal(data, &cfg); err != nil { if optional { // In cloud deploy mode, if YAML parsing fails, return empty config instead of error. 
@@ -370,6 +404,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Claude key headers cfg.SanitizeClaudeKeys() + // Sanitize Kiro keys: trim whitespace from credential fields + cfg.SanitizeKiroKeys() + // Sanitize OpenAI compatibility providers: drop entries without base-url cfg.SanitizeOpenAICompatibility() @@ -446,6 +483,22 @@ func (cfg *Config) SanitizeClaudeKeys() { } } +// SanitizeKiroKeys trims whitespace from Kiro credential fields. +func (cfg *Config) SanitizeKiroKeys() { + if cfg == nil || len(cfg.KiroKey) == 0 { + return + } + for i := range cfg.KiroKey { + entry := &cfg.KiroKey[i] + entry.TokenFile = strings.TrimSpace(entry.TokenFile) + entry.AccessToken = strings.TrimSpace(entry.AccessToken) + entry.RefreshToken = strings.TrimSpace(entry.RefreshToken) + entry.ProfileArn = strings.TrimSpace(entry.ProfileArn) + entry.Region = strings.TrimSpace(entry.Region) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + } +} + // SanitizeGeminiKeys deduplicates and normalizes Gemini credentials. func (cfg *Config) SanitizeGeminiKeys() { if cfg == nil { diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 58b388a1..1dbeecde 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -24,4 +24,7 @@ const ( // Antigravity represents the Antigravity response format identifier. Antigravity = "antigravity" + + // Kiro represents the AWS CodeWhisperer (Kiro) provider identifier. 
+ Kiro = "kiro" ) diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index 28fde213..74a79efc 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -38,13 +38,16 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { timestamp := entry.Time.Format("2006-01-02 15:04:05") message := strings.TrimRight(entry.Message, "\r\n") - - var formatted string + + // Handle nil Caller (can happen with some log entries) + callerFile := "unknown" + callerLine := 0 if entry.Caller != nil { - formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message) - } else { - formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message) + callerFile = filepath.Base(entry.Caller.File) + callerLine = entry.Caller.Line } + + formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message) buffer.WriteString(formatted) return buffer.Bytes(), nil @@ -55,6 +58,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { func SetupBaseLogger() { setupOnce.Do(func() { log.SetOutput(os.Stdout) + log.SetLevel(log.InfoLevel) log.SetReportCaller(true) log.SetFormatter(&LogFormatter{}) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 36aa83bb..9f10eeac 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -986,3 +986,161 @@ func GetIFlowModels() []*ModelInfo { } return models } + +// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions +func GetKiroModels() []*ModelInfo { + return []*ModelInfo{ + { + ID: "kiro-auto", + Object: "model", + Created: 1732752000, // 2024-11-28 + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Auto", + Description: "Automatic model selection by AWS CodeWhisperer", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: 
"kiro-claude-opus-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5", + Description: "Claude Opus 4.5 via Kiro (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-sonnet-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4", + Description: "Claude Sonnet 4 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-haiku-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Kiro (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + // --- Chat Variant (No tool calling, for pure conversation) --- + { + ID: "kiro-claude-opus-4.5-chat", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5 (Chat)", + Description: "Claude Opus 4.5 for chat only (no tool calling)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + // --- Agentic Variants (Optimized for coding agents with chunked writes) --- + { + ID: "kiro-claude-opus-4.5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5 (Agentic)", + Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-sonnet-4.5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)", + Description: "Claude 
Sonnet 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} + +// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions. +// These models use the same API as Kiro and share the same executor. +func GetAmazonQModels() []*ModelInfo { + return []*ModelInfo{ + { + ID: "amazonq-auto", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", // Uses Kiro executor - same API + DisplayName: "Amazon Q Auto", + Description: "Automatic model selection by Amazon Q", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-opus-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Opus 4.5", + Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4", + Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-haiku-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 73914750..c124c829 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ 
b/internal/runtime/executor/antigravity_executor.go @@ -12,6 +12,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "github.com/google/uuid" @@ -41,7 +42,10 @@ const ( streamScannerBuffer int = 20_971_520 ) -var randSource = rand.New(rand.NewSource(time.Now().UnixNano())) +var ( + randSource = rand.New(rand.NewSource(time.Now().UnixNano())) + randSourceMutex sync.Mutex +) // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { @@ -754,15 +758,19 @@ func generateRequestID() string { } func generateSessionID() string { + randSourceMutex.Lock() n := randSource.Int63n(9_000_000_000_000_000_000) + randSourceMutex.Unlock() return "-" + strconv.FormatInt(n, 10) } func generateProjectID() string { adjectives := []string{"useful", "bright", "swift", "calm", "bold"} nouns := []string{"fuze", "wave", "spark", "flow", "core"} + randSourceMutex.Lock() adj := adjectives[randSource.Intn(len(adjectives))] noun := nouns[randSource.Intn(len(nouns))] + randSourceMutex.Unlock() randomPart := strings.ToLower(uuid.NewString())[:5] return adj + "-" + noun + "-" + randomPart } diff --git a/internal/runtime/executor/cache_helpers.go b/internal/runtime/executor/cache_helpers.go index 5272686b..4b553662 100644 --- a/internal/runtime/executor/cache_helpers.go +++ b/internal/runtime/executor/cache_helpers.go @@ -1,10 +1,38 @@ package executor -import "time" +import ( + "sync" + "time" +) type codexCache struct { ID string Expire time.Time } -var codexCacheMap = map[string]codexCache{} +var ( + codexCacheMap = map[string]codexCache{} + codexCacheMutex sync.RWMutex +) + +// getCodexCache safely retrieves a cache entry +func getCodexCache(key string) (codexCache, bool) { + codexCacheMutex.RLock() + defer codexCacheMutex.RUnlock() + cache, ok := codexCacheMap[key] + return cache, ok +} + +// setCodexCache safely sets a cache entry +func setCodexCache(key string, cache codexCache) { + codexCacheMutex.Lock() + defer codexCacheMutex.Unlock() + 
codexCacheMap[key] = cache +} + +// deleteCodexCache safely deletes a cache entry +func deleteCodexCache(key string) { + codexCacheMutex.Lock() + defer codexCacheMutex.Unlock() + delete(codexCacheMap, key) +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 1c4291f6..c14b7be8 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -506,12 +506,12 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form if userIDResult.Exists() { var hasKey bool key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) - if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) { + if cache, hasKey = getCodexCache(key); !hasKey || cache.Expire.Before(time.Now()) { cache = codexCache{ ID: uuid.New().String(), Expire: time.Now().Add(1 * time.Hour), } - codexCacheMap[key] = cache + setCodexCache(key, cache) } } } else if from == "openai-response" { diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go new file mode 100644 index 00000000..eb068c14 --- /dev/null +++ b/internal/runtime/executor/kiro_executor.go @@ -0,0 +1,2353 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/google/uuid" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + 
"github.com/tidwall/gjson" + + "github.com/gin-gonic/gin" +) + +const ( + // kiroEndpoint is the Amazon Q streaming endpoint for chat API (GenerateAssistantResponse). + // Note: This is different from the CodeWhisperer management endpoint (codewhisperer.us-east-1.amazonaws.com) + // used in aws_auth.go for GetUsageLimits, ListProfiles, etc. Both endpoints are correct + // for their respective API operations. + kiroEndpoint = "https://q.us-east-1.amazonaws.com" + kiroTargetChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" + kiroContentType = "application/x-amz-json-1.0" + kiroAcceptStream = "application/vnd.amazon.eventstream" + kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream + kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." + + // kiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. + // AWS Kiro API has a 2-3 minute timeout for large file write operations. + kiroAgenticSystemPrompt = ` +# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY) + +You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure. + +## ABSOLUTE LIMITS +- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS +- **RECOMMENDED 300 LINES** or less for optimal performance +- **NEVER** write entire files in one operation if >300 lines + +## MANDATORY CHUNKED WRITE STRATEGY + +### For NEW FILES (>300 lines total): +1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite +2. THEN: Append remaining content in 250-300 line chunks using file append operations +3. REPEAT: Continue appending until complete + +### For EDITING EXISTING FILES: +1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed +2. NEVER rewrite entire files - use incremental modifications +3. Split large refactors into multiple small, focused edits + +### For LARGE CODE GENERATION: +1. 
Generate in logical sections (imports, types, functions separately) +2. Write each section as a separate operation +3. Use append operations for subsequent sections + +## EXAMPLES OF CORRECT BEHAVIOR + +✅ CORRECT: Writing a 600-line file +- Operation 1: Write lines 1-300 (initial file creation) +- Operation 2: Append lines 301-600 + +✅ CORRECT: Editing multiple functions +- Operation 1: Edit function A +- Operation 2: Edit function B +- Operation 3: Edit function C + +❌ WRONG: Writing 500 lines in single operation → TIMEOUT +❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT +❌ WRONG: Generating massive code blocks without chunking → TIMEOUT + +## WHY THIS MATTERS +- Server has 2-3 minute timeout for operations +- Large writes exceed timeout and FAIL completely +- Chunked writes are FASTER and more RELIABLE +- Failed writes waste time and require retry + +REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.` +) + +// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. +type KiroExecutor struct { + cfg *config.Config + refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions +} + +// NewKiroExecutor creates a new Kiro executor instance. +func NewKiroExecutor(cfg *config.Config) *KiroExecutor { + return &KiroExecutor{cfg: cfg} +} + +// Identifier returns the unique identifier for this executor. +func (e *KiroExecutor) Identifier() string { return "kiro" } + +// PrepareRequest prepares the HTTP request before execution. +func (e *KiroExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } + + +// Execute sends the request to Kiro API and returns the response. +// Supports automatic token refresh on 401/403 errors. 
+func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + accessToken, profileArn := kiroCredentials(auth) + if accessToken == "" { + return resp, fmt.Errorf("kiro: access token not found in auth") + } + if profileArn == "" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + // Check if token is expired before making request + if e.isTokenExpired(accessToken) { + log.Infof("kiro: access token expired, attempting refresh before request") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before request") + } + } + + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + + // Check if this is an agentic model variant + isAgentic := strings.HasSuffix(req.Model, "-agentic") + + // Check if this is a chat-only model variant (no tool calling) + isChatOnly := strings.HasSuffix(req.Model, "-chat") + + // Determine initial origin based on model type + // Opus models use AI_EDITOR (Kiro IDE quota), others start with CLI (Amazon Q quota) + var currentOrigin string + if strings.Contains(strings.ToLower(req.Model), "opus") { + currentOrigin = "AI_EDITOR" + } else { + currentOrigin = "CLI" + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + // Execute with retry on 401/403 and 429 (quota exhausted) + resp, err = e.executeWithRetry(ctx, auth, req, opts, 
accessToken, profileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly)
	return resp, err
}

// executeWithRetry performs the actual HTTP request with automatic retry on auth errors.
// Supports automatic fallback from CLI (Amazon Q) quota to AI_EDITOR (Kiro IDE) quota on 429.
// Retry budget: maxRetries+1 total attempts shared between token refreshes and the origin switch.
func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) {
	var resp cliproxyexecutor.Response
	maxRetries := 2 // Allow retries for token refresh + origin fallback

	for attempt := 0; attempt <= maxRetries; attempt++ {
		url := kiroEndpoint
		httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
		if err != nil {
			return resp, err
		}

		httpReq.Header.Set("Content-Type", kiroContentType)
		httpReq.Header.Set("x-amz-target", kiroTargetChat)
		httpReq.Header.Set("Authorization", "Bearer "+accessToken)
		httpReq.Header.Set("Accept", kiroAcceptStream)

		// Per-auth custom headers (e.g. user-agent overrides) come from auth attributes.
		var attrs map[string]string
		if auth != nil {
			attrs = auth.Attributes
		}
		util.ApplyCustomHeadersFromAttrs(httpReq, attrs)

		var authID, authLabel, authType, authValue string
		if auth != nil {
			authID = auth.ID
			authLabel = auth.Label
			authType, authValue = auth.AccountInfo()
		}
		recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
			URL:       url,
			Method:    http.MethodPost,
			Headers:   httpReq.Header.Clone(),
			Body:      kiroPayload,
			Provider:  e.Identifier(),
			AuthID:    authID,
			AuthLabel: authLabel,
			AuthType:  authType,
			AuthValue: authValue,
		})

		// Non-streaming path uses a hard 120s timeout (the streaming variant uses 0 = none).
		httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second)
		httpResp, err := httpClient.Do(httpReq)
		if err != nil {
			recordAPIResponseError(ctx, e.cfg, err)
			return resp, err
		}
		recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())

		// Handle 429 errors (quota exhausted) with origin fallback
		if httpResp.StatusCode == 429 {
			respBody, _ := io.ReadAll(httpResp.Body)
			_ = httpResp.Body.Close()
			appendAPIResponseChunk(ctx, e.cfg, respBody)

			// If currently using CLI quota and it's exhausted, switch to AI_EDITOR (Kiro IDE) quota
			if currentOrigin == "CLI" {
				log.Warnf("kiro: Amazon Q (CLI) quota exhausted (429), switching to Kiro (AI_EDITOR) fallback")
				currentOrigin = "AI_EDITOR"

				// Rebuild payload with new origin
				kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly)

				// Retry with new origin
				continue
			}

			// Already on AI_EDITOR or other origin, return the error
			log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
			return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
		}

		// Handle 401/403 errors with token refresh and retry
		if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 {
			respBody, _ := io.ReadAll(httpResp.Body)
			_ = httpResp.Body.Close()
			appendAPIResponseChunk(ctx, e.cfg, respBody)

			if attempt < maxRetries {
				log.Warnf("kiro: received %d error, attempting token refresh and retry (attempt %d/%d)", httpResp.StatusCode, attempt+1, maxRetries+1)

				refreshedAuth, refreshErr := e.Refresh(ctx, auth)
				if refreshErr != nil {
					log.Errorf("kiro: token refresh failed: %v", refreshErr)
					return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
				}

				if refreshedAuth != nil {
					auth = refreshedAuth
					accessToken, profileArn = kiroCredentials(auth)
					// Rebuild payload with new profile ARN if changed
					kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly)
					log.Infof("kiro: token refreshed successfully, retrying request")
					continue
				}
			}

			log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
			return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
		}

		if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
			b, _ := io.ReadAll(httpResp.Body)
			appendAPIResponseChunk(ctx, e.cfg, b)
			log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
			err = statusErr{code: httpResp.StatusCode, msg: string(b)}
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("response body close error: %v", errClose)
			}
			return resp, err
		}

		// NOTE: defer inside the loop is safe here — every retry path above closes the
		// body explicitly before `continue`, so only the final successful response's
		// close is deferred to function return.
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("response body close error: %v", errClose)
			}
		}()

		content, toolUses, usageInfo, err := e.parseEventStream(httpResp.Body)
		if err != nil {
			recordAPIResponseError(ctx, e.cfg, err)
			return resp, err
		}

		// Fallback for usage if missing from upstream
		// NOTE(review): parseEventStream appears to set only InputTokens/OutputTokens,
		// never TotalTokens, so this fallback may fire on every response and overwrite
		// upstream-reported token counts — confirm against usage.Detail semantics.
		if usageInfo.TotalTokens == 0 {
			if enc, encErr := tokenizerForModel(req.Model); encErr == nil {
				if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil {
					usageInfo.InputTokens = inp
				}
			}
			if len(content) > 0 {
				// Rough estimate: ~4 bytes per token, minimum 1 for non-empty output.
				usageInfo.OutputTokens = int64(len(content) / 4)
				if usageInfo.OutputTokens == 0 {
					usageInfo.OutputTokens = 1
				}
			}
			usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
		}

		appendAPIResponseChunk(ctx, e.cfg, []byte(content))
		reporter.publish(ctx, usageInfo)

		// Build response in Claude format for Kiro translator
		kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo)
		out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
		resp = cliproxyexecutor.Response{Payload: []byte(out)}
		return resp, nil
	}

	return resp, fmt.Errorf("kiro: max retries exceeded")
}

// ExecuteStream handles streaming requests to Kiro API.
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
	accessToken, profileArn := kiroCredentials(auth)
	if accessToken == "" {
		return nil, fmt.Errorf("kiro: access token not found in auth")
	}
	if profileArn == "" {
		log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
	}

	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
	defer reporter.trackFailure(ctx, &err)

	// Check if token is expired before making request.
	// A failed pre-refresh is only a warning: the retry loop below still handles 401/403.
	if e.isTokenExpired(accessToken) {
		log.Infof("kiro: access token expired, attempting refresh before stream request")
		refreshedAuth, refreshErr := e.Refresh(ctx, auth)
		if refreshErr != nil {
			log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
		} else if refreshedAuth != nil {
			auth = refreshedAuth
			accessToken, profileArn = kiroCredentials(auth)
			log.Infof("kiro: token refreshed successfully before stream request")
		}
	}

	from := opts.SourceFormat
	to := sdktranslator.FromString("kiro")
	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)

	kiroModelID := e.mapModelToKiro(req.Model)

	// Check if this is an agentic model variant
	isAgentic := strings.HasSuffix(req.Model, "-agentic")

	// Check if this is a chat-only model variant (no tool calling)
	isChatOnly := strings.HasSuffix(req.Model, "-chat")

	// Determine initial origin based on model type
	// Opus models use AI_EDITOR (Kiro IDE quota), others start with CLI (Amazon Q quota)
	var currentOrigin string
	if strings.Contains(strings.ToLower(req.Model), "opus") {
		currentOrigin = "AI_EDITOR"
	} else {
		currentOrigin = "CLI"
	}

	kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly)

	// Execute stream with retry on 401/403 and 429 (quota exhausted)
	return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly)
}

// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
// Supports automatic fallback from CLI (Amazon Q) quota to AI_EDITOR (Kiro IDE) quota on 429.
// On success, ownership of the response body passes to the goroutine feeding the channel.
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) {
	maxRetries := 2 // Allow retries for token refresh + origin fallback

	for attempt := 0; attempt <= maxRetries; attempt++ {
		url := kiroEndpoint
		httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
		if err != nil {
			return nil, err
		}

		httpReq.Header.Set("Content-Type", kiroContentType)
		httpReq.Header.Set("x-amz-target", kiroTargetChat)
		httpReq.Header.Set("Authorization", "Bearer "+accessToken)
		httpReq.Header.Set("Accept", kiroAcceptStream)

		var attrs map[string]string
		if auth != nil {
			attrs = auth.Attributes
		}
		util.ApplyCustomHeadersFromAttrs(httpReq, attrs)

		var authID, authLabel, authType, authValue string
		if auth != nil {
			authID = auth.ID
			authLabel = auth.Label
			authType, authValue = auth.AccountInfo()
		}
		recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
			URL:       url,
			Method:    http.MethodPost,
			Headers:   httpReq.Header.Clone(),
			Body:      kiroPayload,
			Provider:  e.Identifier(),
			AuthID:    authID,
			AuthLabel: authLabel,
			AuthType:  authType,
			AuthValue: authValue,
		})

		// Timeout 0: streaming responses must not be cut off by a client deadline.
		httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
		httpResp, err := httpClient.Do(httpReq)
		if err != nil {
			recordAPIResponseError(ctx, e.cfg, err)
			return nil, err
		}
		recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())

		// Handle 429 errors (quota exhausted) with origin fallback
		if httpResp.StatusCode == 429 {
			respBody, _ := io.ReadAll(httpResp.Body)
			_ = httpResp.Body.Close()
			appendAPIResponseChunk(ctx, e.cfg, respBody)

			// If currently using CLI quota and it's exhausted, switch to AI_EDITOR (Kiro IDE) quota
			if currentOrigin == "CLI" {
				log.Warnf("kiro: stream Amazon Q (CLI) quota exhausted (429), switching to Kiro (AI_EDITOR) fallback")
				currentOrigin = "AI_EDITOR"

				// Rebuild payload with new origin
				kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly)

				// Retry with new origin
				continue
			}

			// Already on AI_EDITOR or other origin, return the error
			log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(respBody))
			return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
		}

		// Handle 401/403 errors with token refresh and retry
		if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 {
			respBody, _ := io.ReadAll(httpResp.Body)
			_ = httpResp.Body.Close()
			appendAPIResponseChunk(ctx, e.cfg, respBody)

			if attempt < maxRetries {
				log.Warnf("kiro: stream received %d error, attempting token refresh and retry (attempt %d/%d)", httpResp.StatusCode, attempt+1, maxRetries+1)

				refreshedAuth, refreshErr := e.Refresh(ctx, auth)
				if refreshErr != nil {
					log.Errorf("kiro: token refresh failed: %v", refreshErr)
					return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
				}

				if refreshedAuth != nil {
					auth = refreshedAuth
					accessToken, profileArn = kiroCredentials(auth)
					// Rebuild payload with new profile ARN if changed
					kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly)
					log.Infof("kiro: token refreshed successfully, retrying stream request")
					continue
				}
			}

			log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(respBody))
			return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
		}

		if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
			b, _ := io.ReadAll(httpResp.Body)
			appendAPIResponseChunk(ctx, e.cfg, b)
			log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b))
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("response body close error: %v", errClose)
			}
			return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
		}

		out := make(chan cliproxyexecutor.StreamChunk)

		// The goroutine owns the response body from here; it closes both the
		// channel and the body when streaming finishes or the context is cancelled.
		go func(resp *http.Response) {
			defer close(out)
			defer func() {
				if errClose := resp.Body.Close(); errClose != nil {
					log.Errorf("response body close error: %v", errClose)
				}
			}()

			e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter)
		}(httpResp)

		return out, nil
	}

	return nil, fmt.Errorf("kiro: max retries exceeded for stream")
}


// kiroCredentials extracts access token and profile ARN from auth.
+func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { + if auth == nil { + return "", "" + } + + // Try Metadata first (wrapper format) + if auth.Metadata != nil { + if token, ok := auth.Metadata["access_token"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profile_arn"].(string); ok { + profileArn = arn + } + } + + // Try Attributes + if accessToken == "" && auth.Attributes != nil { + accessToken = auth.Attributes["access_token"] + profileArn = auth.Attributes["profile_arn"] + } + + // Try direct fields from flat JSON format (new AWS Builder ID format) + if accessToken == "" && auth.Metadata != nil { + if token, ok := auth.Metadata["accessToken"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profileArn"].(string); ok { + profileArn = arn + } + } + + return accessToken, profileArn +} + +// mapModelToKiro maps external model names to Kiro model IDs. +// Supports both Kiro and Amazon Q prefixes since they use the same API. +// Agentic variants (-agentic suffix) map to the same backend model IDs. 
func (e *KiroExecutor) mapModelToKiro(model string) string {
	modelMap := map[string]string{
		// Proxy format (kiro- prefix)
		"kiro-auto":              "auto",
		"kiro-claude-opus-4.5":   "claude-opus-4.5",
		"kiro-claude-sonnet-4.5": "claude-sonnet-4.5",
		"kiro-claude-sonnet-4":   "claude-sonnet-4",
		"kiro-claude-haiku-4.5":  "claude-haiku-4.5",
		// Amazon Q format (amazonq- prefix) - same API as Kiro
		"amazonq-auto":              "auto",
		"amazonq-claude-opus-4.5":   "claude-opus-4.5",
		"amazonq-claude-sonnet-4.5": "claude-sonnet-4.5",
		"amazonq-claude-sonnet-4":   "claude-sonnet-4",
		"amazonq-claude-haiku-4.5":  "claude-haiku-4.5",
		// Native Kiro format (no prefix) - used by Kiro IDE directly
		"claude-opus-4.5":   "claude-opus-4.5",
		"claude-sonnet-4.5": "claude-sonnet-4.5",
		"claude-sonnet-4":   "claude-sonnet-4",
		"claude-haiku-4.5":  "claude-haiku-4.5",
		"auto":              "auto",
		// Chat variant (no tool calling support)
		"kiro-claude-opus-4.5-chat": "claude-opus-4.5",
		// Agentic variants (same backend model IDs, but with special system prompt)
		"kiro-claude-opus-4.5-agentic":      "claude-opus-4.5",
		"kiro-claude-sonnet-4.5-agentic":    "claude-sonnet-4.5",
		"kiro-claude-sonnet-4-agentic":      "claude-sonnet-4",
		"kiro-claude-haiku-4.5-agentic":     "claude-haiku-4.5",
		"amazonq-claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
	}
	if kiroID, ok := modelMap[model]; ok {
		return kiroID
	}
	// Unknown names degrade to server-side model selection rather than failing.
	log.Debugf("kiro: unknown model '%s', falling back to 'auto'", model)
	return "auto"
}

// Kiro API request structs - field order determines JSON key order

// kiroPayload is the top-level request body sent to the Kiro chat endpoint.
type kiroPayload struct {
	ConversationState kiroConversationState `json:"conversationState"`
	ProfileArn        string                `json:"profileArn,omitempty"`
}

// kiroConversationState carries the full conversation: prior turns plus the
// message currently being sent.
type kiroConversationState struct {
	ConversationID  string               `json:"conversationId"`
	History         []kiroHistoryMessage `json:"history"`
	CurrentMessage  kiroCurrentMessage   `json:"currentMessage"`
	ChatTriggerType string               `json:"chatTriggerType"` // Required: "MANUAL"
}

// kiroCurrentMessage wraps the user message of the current turn.
type kiroCurrentMessage struct {
	UserInputMessage kiroUserInputMessage `json:"userInputMessage"`
}

// kiroHistoryMessage is one prior turn; exactly one of the two fields is set.
type kiroHistoryMessage struct {
	UserInputMessage         *kiroUserInputMessage         `json:"userInputMessage,omitempty"`
	AssistantResponseMessage *kiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
}

// kiroUserInputMessage is a user turn, including the origin (quota pool) and
// optional tool context.
type kiroUserInputMessage struct {
	Content                 string                       `json:"content"`
	ModelID                 string                       `json:"modelId"`
	Origin                  string                       `json:"origin"`
	UserInputMessageContext *kiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
}

// kiroUserInputMessageContext carries tool definitions and tool results for a
// user turn.
type kiroUserInputMessageContext struct {
	ToolResults []kiroToolResult  `json:"toolResults,omitempty"`
	Tools       []kiroToolWrapper `json:"tools,omitempty"`
}

// kiroToolResult reports the outcome of a prior tool invocation.
type kiroToolResult struct {
	ToolUseID string            `json:"toolUseId"`
	Content   []kiroTextContent `json:"content"`
	Status    string            `json:"status"`
}

// kiroTextContent is a plain-text content item.
type kiroTextContent struct {
	Text string `json:"text"`
}

// kiroToolWrapper wraps a tool specification, matching the API's nesting.
type kiroToolWrapper struct {
	ToolSpecification kiroToolSpecification `json:"toolSpecification"`
}

// kiroToolSpecification describes one callable tool.
type kiroToolSpecification struct {
	Name        string          `json:"name"`
	Description string          `json:"description"`
	InputSchema kiroInputSchema `json:"inputSchema"`
}

// kiroInputSchema holds the tool's JSON Schema under a "json" key.
type kiroInputSchema struct {
	JSON interface{} `json:"json"`
}

// kiroAssistantResponseMessage is an assistant turn, with any tool calls made.
type kiroAssistantResponseMessage struct {
	Content  string        `json:"content"`
	ToolUses []kiroToolUse `json:"toolUses,omitempty"`
}

// kiroToolUse is a single tool invocation by the assistant.
type kiroToolUse struct {
	ToolUseID string                 `json:"toolUseId"`
	Name      string                 `json:"name"`
	Input     map[string]interface{} `json:"input"`
}

// buildKiroPayload constructs the Kiro API request payload.
// Supports tool calling - tools are passed via userInputMessageContext.
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte {
	messages := gjson.GetBytes(claudeBody, "messages")

	// For chat-only mode, don't include tools
	var tools gjson.Result
	if !isChatOnly {
		tools = gjson.GetBytes(claudeBody, "tools")
	}

	// Extract system prompt - can be string or array of content blocks
	systemField := gjson.GetBytes(claudeBody, "system")
	var systemPrompt string
	if systemField.IsArray() {
		// System is array of content blocks, extract text
		var sb strings.Builder
		for _, block := range systemField.Array() {
			if block.Get("type").String() == "text" {
				sb.WriteString(block.Get("text").String())
			} else if block.Type == gjson.String {
				sb.WriteString(block.String())
			}
		}
		systemPrompt = sb.String()
	} else {
		systemPrompt = systemField.String()
	}

	// Inject agentic optimization prompt for -agentic model variants
	// This prevents AWS Kiro API timeouts during large file write operations
	if isAgentic {
		if systemPrompt != "" {
			systemPrompt += "\n"
		}
		systemPrompt += kiroAgenticSystemPrompt
	}

	// Convert Claude tools to Kiro format
	var kiroTools []kiroToolWrapper
	if tools.IsArray() {
		for _, tool := range tools.Array() {
			name := tool.Get("name").String()
			description := tool.Get("description").String()
			inputSchema := tool.Get("input_schema").Value()

			// Truncate long descriptions (Kiro API limit is in bytes)
			// Truncate at valid UTF-8 boundary to avoid breaking multi-byte chars
			if len(description) > kiroMaxToolDescLen {
				// Find a valid UTF-8 boundary before the limit
				truncLen := kiroMaxToolDescLen
				for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
					truncLen--
				}
				description = description[:truncLen] + "..."
			}

			kiroTools = append(kiroTools, kiroToolWrapper{
				ToolSpecification: kiroToolSpecification{
					Name:        name,
					Description: description,
					InputSchema: kiroInputSchema{JSON: inputSchema},
				},
			})
		}
	}

	var history []kiroHistoryMessage
	var currentUserMsg *kiroUserInputMessage
	var currentToolResults []kiroToolResult

	// Split messages: the last user message becomes currentMessage, everything
	// earlier goes into history.
	messagesArray := messages.Array()
	for i, msg := range messagesArray {
		role := msg.Get("role").String()
		isLastMessage := i == len(messagesArray)-1

		if role == "user" {
			userMsg, toolResults := e.buildUserMessageStruct(msg, modelID, origin)
			if isLastMessage {
				currentUserMsg = &userMsg
				currentToolResults = toolResults
			} else {
				// For history messages, embed tool results in context
				if len(toolResults) > 0 {
					userMsg.UserInputMessageContext = &kiroUserInputMessageContext{
						ToolResults: toolResults,
					}
				}
				history = append(history, kiroHistoryMessage{
					UserInputMessage: &userMsg,
				})
			}
		} else if role == "assistant" {
			assistantMsg := e.buildAssistantMessageStruct(msg)
			history = append(history, kiroHistoryMessage{
				AssistantResponseMessage: &assistantMsg,
			})
		}
	}

	// Build content with system prompt
	// The Kiro API has no dedicated system field, so the prompt is inlined into
	// the current user message between sentinel markers.
	if currentUserMsg != nil {
		var contentBuilder strings.Builder

		// Add system prompt if present
		if systemPrompt != "" {
			contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
			contentBuilder.WriteString(systemPrompt)
			contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
		}

		// Add the actual user message
		contentBuilder.WriteString(currentUserMsg.Content)
		currentUserMsg.Content = contentBuilder.String()

		// Build userInputMessageContext with tools and tool results
		if len(kiroTools) > 0 || len(currentToolResults) > 0 {
			currentUserMsg.UserInputMessageContext = &kiroUserInputMessageContext{
				Tools:       kiroTools,
				ToolResults: currentToolResults,
			}
		}
	}

	// Build payload using structs (preserves key order)
	var currentMessage kiroCurrentMessage
	if currentUserMsg != nil {
		currentMessage = kiroCurrentMessage{UserInputMessage: *currentUserMsg}
	} else {
		// Fallback when no user messages - still include system prompt if present
		fallbackContent := ""
		if systemPrompt != "" {
			fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
		}
		currentMessage = kiroCurrentMessage{UserInputMessage: kiroUserInputMessage{
			Content: fallbackContent,
			ModelID: modelID,
			Origin:  origin,
		}}
	}

	payload := kiroPayload{
		ConversationState: kiroConversationState{
			ConversationID:  uuid.New().String(),
			History:         history,
			CurrentMessage:  currentMessage,
			ChatTriggerType: "MANUAL", // Required by Kiro API
		},
		ProfileArn: profileArn,
	}

	// Ensure history is not nil (empty array)
	if payload.ConversationState.History == nil {
		payload.ConversationState.History = []kiroHistoryMessage{}
	}

	result, err := json.Marshal(payload)
	if err != nil {
		// NOTE(review): returns nil on marshal failure; callers pass this straight
		// to bytes.NewReader (empty request body) — confirm that is intended.
		log.Debugf("kiro: failed to marshal payload: %v", err)
		return nil
	}
	return result
}

// buildUserMessageStruct builds a user message and extracts tool results
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
+func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin string) (kiroUserInputMessage, []kiroToolResult) { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolResults []kiroToolResult + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "tool_result": + // Extract tool result for API + toolUseID := part.Get("tool_use_id").String() + isError := part.Get("is_error").Bool() + resultContent := part.Get("content") + + // Convert content to Kiro format: [{text: "..."}] + var textContents []kiroTextContent + if resultContent.IsArray() { + for _, item := range resultContent.Array() { + if item.Get("type").String() == "text" { + textContents = append(textContents, kiroTextContent{Text: item.Get("text").String()}) + } else if item.Type == gjson.String { + textContents = append(textContents, kiroTextContent{Text: item.String()}) + } + } + } else if resultContent.Type == gjson.String { + textContents = append(textContents, kiroTextContent{Text: resultContent.String()}) + } + + // If no content, add default message + if len(textContents) == 0 { + textContents = append(textContents, kiroTextContent{Text: "Tool use was cancelled by the user"}) + } + + status := "success" + if isError { + status = "error" + } + + toolResults = append(toolResults, kiroToolResult{ + ToolUseID: toolUseID, + Content: textContents, + Status: status, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + return kiroUserInputMessage{ + Content: contentBuilder.String(), + ModelID: modelID, + Origin: origin, + }, toolResults +} + +// buildAssistantMessageStruct builds an assistant message with tool uses +func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssistantResponseMessage { + content := msg.Get("content") + var contentBuilder strings.Builder + var 
toolUses []kiroToolUse + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "tool_use": + // Extract tool use for API + toolUseID := part.Get("id").String() + toolName := part.Get("name").String() + toolInput := part.Get("input") + + // Convert input to map + var inputMap map[string]interface{} + if toolInput.IsObject() { + inputMap = make(map[string]interface{}) + toolInput.ForEach(func(key, value gjson.Result) bool { + inputMap[key.String()] = value.Value() + return true + }) + } + + toolUses = append(toolUses, kiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + return kiroAssistantResponseMessage{ + Content: contentBuilder.String(), + ToolUses: toolUses, + } +} + +// NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults + +// parseEventStream parses AWS Event Stream binary format. +// Extracts text content and tool uses from the response. +// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. 
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, error) {
	var content strings.Builder
	var toolUses []kiroToolUse
	var usageInfo usage.Detail
	reader := bufio.NewReader(body)

	// Tool use state tracking for input buffering and deduplication
	processedIDs := make(map[string]bool)
	var currentToolUse *toolUseState

	// Each AWS event-stream message: 4-byte total length, 4-byte headers length,
	// 4-byte prelude CRC, headers, payload, 4-byte message CRC (CRCs not verified here).
	for {
		prelude := make([]byte, 8)
		_, err := io.ReadFull(reader, prelude)
		if err == io.EOF {
			break
		}
		if err != nil {
			return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read prelude: %w", err)
		}

		totalLen := binary.BigEndian.Uint32(prelude[0:4])
		if totalLen < 8 {
			return content.String(), toolUses, usageInfo, fmt.Errorf("invalid message length: %d", totalLen)
		}
		if totalLen > kiroMaxMessageSize {
			return content.String(), toolUses, usageInfo, fmt.Errorf("message too large: %d bytes", totalLen)
		}
		headersLen := binary.BigEndian.Uint32(prelude[4:8])

		remaining := make([]byte, totalLen-8)
		_, err = io.ReadFull(reader, remaining)
		if err != nil {
			return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err)
		}

		// Extract event type from headers
		// NOTE(review): remaining[:headersLen+4] can panic if a corrupt headersLen
		// exceeds len(remaining); likewise payloadEnd underflows (wraps) when
		// len(remaining) < 4 — consider bounds checks for hostile input.
		eventType := e.extractEventType(remaining[:headersLen+4])

		payloadStart := 4 + headersLen
		payloadEnd := uint32(len(remaining)) - 4
		if payloadStart >= payloadEnd {
			continue
		}

		payload := remaining[payloadStart:payloadEnd]

		var event map[string]interface{}
		if err := json.Unmarshal(payload, &event); err != nil {
			log.Debugf("kiro: skipping malformed event: %v", err)
			continue
		}

		// Handle different event types
		switch eventType {
		case "assistantResponseEvent":
			if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok {
				if contentText, ok := assistantResp["content"].(string); ok {
					content.WriteString(contentText)
				}
				// Extract tool uses from response
				if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok {
					for _, tuRaw := range toolUsesRaw {
						if tu, ok := tuRaw.(map[string]interface{}); ok {
							toolUseID := getString(tu, "toolUseId")
							// Check for duplicate
							if processedIDs[toolUseID] {
								log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID)
								continue
							}
							processedIDs[toolUseID] = true

							toolUse := kiroToolUse{
								ToolUseID: toolUseID,
								Name:      getString(tu, "name"),
							}
							if input, ok := tu["input"].(map[string]interface{}); ok {
								toolUse.Input = input
							}
							toolUses = append(toolUses, toolUse)
						}
					}
				}
			}
			// Also try direct format (payload not nested under the event name)
			if contentText, ok := event["content"].(string); ok {
				content.WriteString(contentText)
			}
			// Direct tool uses
			if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok {
				for _, tuRaw := range toolUsesRaw {
					if tu, ok := tuRaw.(map[string]interface{}); ok {
						toolUseID := getString(tu, "toolUseId")
						// Check for duplicate
						if processedIDs[toolUseID] {
							log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID)
							continue
						}
						processedIDs[toolUseID] = true

						toolUse := kiroToolUse{
							ToolUseID: toolUseID,
							Name:      getString(tu, "name"),
						}
						if input, ok := tu["input"].(map[string]interface{}); ok {
							toolUse.Input = input
						}
						toolUses = append(toolUses, toolUse)
					}
				}
			}

		case "toolUseEvent":
			// Handle dedicated tool use events with input buffering
			completedToolUses, newState := e.processToolUseEvent(event, currentToolUse, processedIDs)
			currentToolUse = newState
			toolUses = append(toolUses, completedToolUses...)

		case "supplementaryWebLinksEvent":
			// NOTE(review): only InputTokens/OutputTokens are set here, never
			// TotalTokens — callers that gate on TotalTokens == 0 will always
			// recompute usage. Confirm whether TotalTokens should be set too.
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
			}
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
			}
		}

		// Also check nested supplementaryWebLinksEvent
		if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok {
			if inputTokens, ok := usageEvent["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
			}
			if outputTokens, ok := usageEvent["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
			}
		}
	}

	// Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}])
	contentStr := content.String()
	cleanedContent, embeddedToolUses := e.parseEmbeddedToolCalls(contentStr, processedIDs)
	toolUses = append(toolUses, embeddedToolUses...)

	// Deduplicate all tool uses
	toolUses = deduplicateToolUses(toolUses)

	return cleanedContent, toolUses, usageInfo, nil
}

// extractEventType extracts the event type from AWS Event Stream headers.
// headerBytes is the prelude CRC (4 bytes) followed by the encoded header list;
// returns "" when the :event-type header is absent or the headers are malformed.
func (e *KiroExecutor) extractEventType(headerBytes []byte) string {
	// Skip prelude CRC (4 bytes)
	if len(headerBytes) < 4 {
		return ""
	}
	headers := headerBytes[4:]

	offset := 0
	for offset < len(headers) {
		if offset >= len(headers) {
			break
		}
		// Header layout: 1-byte name length, name, 1-byte value type, value.
		nameLen := int(headers[offset])
		offset++
		if offset+nameLen > len(headers) {
			break
		}
		name := string(headers[offset : offset+nameLen])
		offset += nameLen

		if offset >= len(headers) {
			break
		}
		valueType := headers[offset]
		offset++

		if valueType == 7 { // String type
			if offset+2 > len(headers) {
				break
			}
			valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
			offset += 2
			if offset+valueLen > len(headers) {
				break
			}
			value := string(headers[offset : offset+valueLen])
			offset += valueLen

			if name == ":event-type" {
				return value
			}
		} else {
			// Skip other types
			// NOTE(review): a non-string header aborts the scan entirely because
			// other value types have unhandled lengths; a later :event-type header
			// would be missed — acceptable only if the server always emits strings.
			break
		}
	}
	return ""
}

// getString safely extracts a string from a map
func getString(m map[string]interface{}, key string) string {
	if v, ok := m[key].(string); ok {
		return v
	}
	return ""
}

// buildClaudeResponse constructs a Claude-compatible response.
// Supports tool_use blocks when tools are present in the response.
func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte {
	var contentBlocks []map[string]interface{}

	// Add text content if present
	if content != "" {
		contentBlocks = append(contentBlocks, map[string]interface{}{
			"type": "text",
			"text": content,
		})
	}

	// Add tool_use blocks
	for _, toolUse := range toolUses {
		contentBlocks = append(contentBlocks, map[string]interface{}{
			"type":  "tool_use",
			"id":    toolUse.ToolUseID,
			"name":  toolUse.Name,
			"input": toolUse.Input,
		})
	}

	// Ensure at least one content block (Claude API requires non-empty content)
	if len(contentBlocks) == 0 {
		contentBlocks = append(contentBlocks, map[string]interface{}{
			"type": "text",
			"text": "",
		})
	}

	// Determine stop reason
	stopReason := "end_turn"
	if len(toolUses) > 0 {
		stopReason = "tool_use"
	}

	response := map[string]interface{}{
		"id":          "msg_" + uuid.New().String()[:24],
		"type":        "message",
		"role":        "assistant",
		"model":       model,
		"content":     contentBlocks,
		"stop_reason": stopReason,
		"usage": map[string]interface{}{
			"input_tokens":  usageInfo.InputTokens,
			"output_tokens": usageInfo.OutputTokens,
		},
	}
	result, _ := json.Marshal(response)
	return result
}

// NOTE: Tool uses are now extracted from API response, not parsed from text


// streamToChannel converts AWS Event Stream to channel-based streaming.
// Supports tool calling - emits tool_use content blocks when tools are used.
// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent.
+func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { + reader := bufio.NewReader(body) + var totalUsage usage.Detail + var hasToolUses bool // Track if any tool uses were emitted + + // Tool use state tracking for input buffering and deduplication + processedIDs := make(map[string]bool) + var currentToolUse *toolUseState + + // Translator param for maintaining tool call state across streaming events + // IMPORTANT: This must persist across all TranslateStream calls + var translatorParam any + + // Pre-calculate input tokens from request if possible + if enc, err := tokenizerForModel(model); err == nil { + // Try OpenAI format first, then fall back to raw byte count estimation + if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + // Fallback: estimate from raw request size (roughly 4 chars per token) + totalUsage.InputTokens = int64(len(originalReq) / 4) + if totalUsage.InputTokens == 0 && len(originalReq) > 0 { + totalUsage.InputTokens = 1 + } + } + log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) + } + + contentBlockIndex := -1 + messageStartSent := false + isTextBlockOpen := false + var outputLen int + + // Ensure usage is published even on early return + defer func() { + reporter.publish(ctx, totalUsage) + }() + + for { + select { + case <-ctx.Done(): + return + default: + } + + prelude := make([]byte, 8) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + break + } + if err != nil { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read prelude: %w", err)} + return + } + + totalLen := binary.BigEndian.Uint32(prelude[0:4]) + if totalLen < 8 { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("invalid 
message length: %d", totalLen)} + return + } + if totalLen > kiroMaxMessageSize { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("message too large: %d bytes", totalLen)} + return + } + headersLen := binary.BigEndian.Uint32(prelude[4:8]) + + remaining := make([]byte, totalLen-8) + _, err = io.ReadFull(reader, remaining) + if err != nil { + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read message: %w", err)} + return + } + + eventType := e.extractEventType(remaining[:headersLen+4]) + + payloadStart := 4 + headersLen + payloadEnd := uint32(len(remaining)) - 4 + if payloadStart >= payloadEnd { + continue + } + + payload := remaining[payloadStart:payloadEnd] + appendAPIResponseChunk(ctx, e.cfg, payload) + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + continue + } + + // Send message_start on first event + if !messageStartSent { + msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + messageStartSent = true + } + + switch eventType { + case "assistantResponseEvent": + var contentDelta string + var toolUses []map[string]interface{} + + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if c, ok := assistantResp["content"].(string); ok { + contentDelta = c + } + // Extract tool uses from response + if tus, ok := assistantResp["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + } + if contentDelta == "" { + if c, ok := event["content"].(string); ok { + contentDelta = c + } + } + // Direct tool uses + if tus, ok := 
event["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + + // Handle text content + if contentDelta != "" { + outputLen += len(contentDelta) + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Handle tool uses in response (with deduplication) + for _, tu := range toolUses { + toolUseID := getString(tu, "toolUseId") + + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + hasToolUses = true + // Close text block if open before starting tool_use block + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + // Emit tool_use content block + 
contentBlockIndex++ + toolName := getString(tu, "name") + + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send input_json_delta with the tool input + if input, ok := tu["input"].(map[string]interface{}); ok { + inputJSON, err := json.Marshal(input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input: %v", err) + // Don't continue - still need to close the block + } else { + inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + // Close tool_use block (always close even if input marshal failed) + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "toolUseEvent": + // Handle dedicated tool use events with input buffering + completedToolUses, newState := e.processToolUseEvent(event, currentToolUse, processedIDs) + currentToolUse = newState + + // Emit completed tool uses + for _, tu := range completedToolUses { + hasToolUses = true + + // Close text block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := 
e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + contentBlockIndex++ + + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + if tu.Input != nil { + inputJSON, err := json.Marshal(tu.Input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) + } else { + inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + + // 
Check nested usage event + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + } + + // Close content block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Fallback for output tokens if not received from upstream + if totalUsage.OutputTokens == 0 && outputLen > 0 { + totalUsage.OutputTokens = int64(outputLen / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + } + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + + // Determine stop reason based on whether tool uses were emitted + stopReason := "end_turn" + if hasToolUses { + stopReason = "tool_use" + } + + // Send message_delta and message_stop + msgStop := e.buildClaudeMessageStopEvent(stopReason, totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + // reporter.publish is called via defer +} + + +// Claude SSE event builders +func (e *KiroExecutor) buildClaudeMessageStartEvent(model string, inputTokens int64) []byte { + event := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": 
"message", + "role": "assistant", + "content": []interface{}{}, + "model": model, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0}, + }, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { + var contentBlock map[string]interface{} + if blockType == "tool_use" { + contentBlock = map[string]interface{}{ + "type": "tool_use", + "id": toolUseID, + "name": toolName, + "input": map[string]interface{}{}, + } + } else { + contentBlock = map[string]interface{}{ + "type": "text", + "text": "", + } + } + + event := map[string]interface{}{ + "type": "content_block_start", + "index": index, + "content_block": contentBlock, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeStreamEvent(contentDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": contentDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +// buildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming +func (e *KiroExecutor) buildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": partialJSON, + }, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { + event := map[string]interface{}{ + "type": "content_block_stop", + "index": index, + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + 
+func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo usage.Detail) []byte { + // First message_delta + deltaEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + deltaResult, _ := json.Marshal(deltaEvent) + + // Then message_stop + stopEvent := map[string]interface{}{ + "type": "message_stop", + } + stopResult, _ := json.Marshal(stopEvent) + + return []byte("data: " + string(deltaResult) + "\n\ndata: " + string(stopResult)) +} + +// buildClaudeFinalEvent constructs the final Claude-style event. +func (e *KiroExecutor) buildClaudeFinalEvent() []byte { + event := map[string]interface{}{ + "type": "message_stop", + } + result, _ := json.Marshal(event) + return []byte("data: " + string(result)) +} + +// CountTokens is not supported for the Kiro provider. +func (e *KiroExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for kiro"} +} + +// Refresh refreshes the Kiro OAuth token. +// Supports both AWS Builder ID (SSO OIDC) and Google OAuth (social login). +// Uses mutex to prevent race conditions when multiple concurrent requests try to refresh. 
+func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + // Serialize token refresh operations to prevent race conditions + e.refreshMu.Lock() + defer e.refreshMu.Unlock() + + log.Debugf("kiro executor: refresh called for auth %s", auth.ID) + if auth == nil { + return nil, fmt.Errorf("kiro executor: auth is nil") + } + + // Double-check: After acquiring lock, verify token still needs refresh + // Another goroutine may have already refreshed while we were waiting + // NOTE: This check has a design limitation - it reads from the auth object passed in, + // not from persistent storage. If another goroutine returns a new Auth object (via Clone), + // this check won't see those updates. The mutex still prevents truly concurrent refreshes, + // but queued goroutines may still attempt redundant refreshes. This is acceptable as + // the refresh operation is idempotent and the extra API calls are infrequent. + if auth.Metadata != nil { + if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { + if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { + // If token was refreshed within the last 30 seconds, skip refresh + if time.Since(refreshTime) < 30*time.Second { + log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") + return auth, nil + } + } + } + // Also check if expires_at is now in the future with sufficient buffer + if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { + if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { + // If token expires more than 2 minutes from now, it's still valid + if time.Until(expTime) > 2*time.Minute { + log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) + return auth, nil + } + } + } + } + + var refreshToken string + var clientID, clientSecret string + var authMethod string + + if auth.Metadata != nil { + if rt, ok := 
auth.Metadata["refresh_token"].(string); ok { + refreshToken = rt + } + if cid, ok := auth.Metadata["client_id"].(string); ok { + clientID = cid + } + if cs, ok := auth.Metadata["client_secret"].(string); ok { + clientSecret = cs + } + if am, ok := auth.Metadata["auth_method"].(string); ok { + authMethod = am + } + } + + if refreshToken == "" { + return nil, fmt.Errorf("kiro executor: refresh token not found") + } + + var tokenData *kiroauth.KiroTokenData + var err error + + // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint + if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + } else { + log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") + oauth := kiroauth.NewKiroOAuth(e.cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("kiro executor: token refresh failed: %w", err) + } + + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + + if updated.Metadata == nil { + updated.Metadata = make(map[string]any) + } + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) + if tokenData.ProfileArn != "" { + updated.Metadata["profile_arn"] = tokenData.ProfileArn + } + if tokenData.AuthMethod != "" { + updated.Metadata["auth_method"] = tokenData.AuthMethod + } + if tokenData.Provider != "" { + updated.Metadata["provider"] = tokenData.Provider + } + // Preserve client credentials for future refreshes (AWS Builder ID) + if tokenData.ClientID != "" { + updated.Metadata["client_id"] = tokenData.ClientID + } + if 
tokenData.ClientSecret != "" { + updated.Metadata["client_secret"] = tokenData.ClientSecret + } + + if updated.Attributes == nil { + updated.Attributes = make(map[string]string) + } + updated.Attributes["access_token"] = tokenData.AccessToken + if tokenData.ProfileArn != "" { + updated.Attributes["profile_arn"] = tokenData.ProfileArn + } + + // Set next refresh time to 30 minutes before expiry + if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { + updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + } + + log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) + return updated, nil +} + +// streamEventStream converts AWS Event Stream to SSE (legacy method for gin.Context). +// Note: For full tool calling support, use streamToChannel instead. +func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c *gin.Context, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) error { + reader := bufio.NewReader(body) + var totalUsage usage.Detail + + // Translator param for maintaining tool call state across streaming events + var translatorParam any + + // Pre-calculate input tokens from request if possible + if enc, err := tokenizerForModel(model); err == nil { + // Try OpenAI format first, then fall back to raw byte count estimation + if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + // Fallback: estimate from raw request size (roughly 4 chars per token) + totalUsage.InputTokens = int64(len(originalReq) / 4) + if totalUsage.InputTokens == 0 && len(originalReq) > 0 { + totalUsage.InputTokens = 1 + } + } + log.Debugf("kiro: streamEventStream pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) + } + + contentBlockIndex := -1 + messageStartSent := false + isBlockOpen := false + var outputLen int + 
+ for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + prelude := make([]byte, 8) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read prelude: %w", err) + } + + totalLen := binary.BigEndian.Uint32(prelude[0:4]) + if totalLen < 8 { + return fmt.Errorf("invalid message length: %d", totalLen) + } + if totalLen > kiroMaxMessageSize { + return fmt.Errorf("message too large: %d bytes", totalLen) + } + headersLen := binary.BigEndian.Uint32(prelude[4:8]) + + remaining := make([]byte, totalLen-8) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return fmt.Errorf("failed to read message: %w", err) + } + + eventType := e.extractEventType(remaining[:headersLen+4]) + + payloadStart := 4 + headersLen + payloadEnd := uint32(len(remaining)) - 4 + if payloadStart >= payloadEnd { + continue + } + + payload := remaining[payloadStart:payloadEnd] + appendAPIResponseChunk(ctx, e.cfg, payload) + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + continue + } + + if !messageStartSent { + msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + messageStartSent = true + } + + switch eventType { + case "assistantResponseEvent": + var contentDelta string + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if ct, ok := assistantResp["content"].(string); ok { + contentDelta = ct + } + } + if contentDelta == "" { + if ct, ok := event["content"].(string); ok { + contentDelta = ct + } + } + + if contentDelta != "" { + outputLen += len(contentDelta) + // Start text content block if needed + if !isBlockOpen { + 
contentBlockIndex++ + isBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + } + + claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + } + + // Note: For full toolUseEvent support, use streamToChannel + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + } + + // Close content block if open + if isBlockOpen && contentBlockIndex >= 0 { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + } + + // Fallback for output tokens if not received from upstream + if totalUsage.OutputTokens == 0 && outputLen > 0 { + 
totalUsage.OutputTokens = int64(outputLen / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + } + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + + // Always use end_turn (no tool_use support) + msgStop := e.buildClaudeMessageStopEvent("end_turn", totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + + c.Writer.Write([]byte("data: [DONE]\n\n")) + c.Writer.Flush() + + reporter.publish(ctx, totalUsage) + return nil +} + +// isTokenExpired checks if a JWT access token has expired. +// Returns true if the token is expired or cannot be parsed. +func (e *KiroExecutor) isTokenExpired(accessToken string) bool { + if accessToken == "" { + return true + } + + // JWT tokens have 3 parts separated by dots + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + // Not a JWT token, assume not expired + return false + } + + // Decode the payload (second part) + // JWT uses base64url encoding without padding (RawURLEncoding) + payload := parts[1] + decoded, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + // Try with padding added as fallback + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + decoded, err = base64.URLEncoding.DecodeString(payload) + if err != nil { + log.Debugf("kiro: failed to decode JWT payload: %v", err) + return false + } + } + + var claims struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(decoded, &claims); err != nil { + log.Debugf("kiro: failed to parse JWT claims: %v", err) + return false + } + + if claims.Exp == 0 { + // No expiration claim, assume not expired + return false + } + + expTime := time.Unix(claims.Exp, 0) + now := time.Now() + + // Consider token expired if it expires within 1 minute 
(buffer for clock skew) + isExpired := now.After(expTime) || expTime.Sub(now) < time.Minute + if isExpired { + log.Debugf("kiro: token expired at %s (now: %s)", expTime.Format(time.RFC3339), now.Format(time.RFC3339)) + } + + return isExpired +} + +// ============================================================================ +// Tool Calling Support - Embedded tool call parsing and input buffering +// Based on amq2api and AIClient-2-API implementations +// ============================================================================ + +// toolUseState tracks the state of an in-progress tool use during streaming. +type toolUseState struct { + toolUseID string + name string + inputBuffer strings.Builder + isComplete bool +} + +// Pre-compiled regex patterns for performance (avoid recompilation on each call) +var ( + // embeddedToolCallPattern matches [Called tool_name with args: {...}] format + // This pattern is used by Kiro when it embeds tool calls in text content + embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+(\w+)\s+with\s+args:\s*`) + // whitespaceCollapsePattern collapses multiple whitespace characters into single space + whitespaceCollapsePattern = regexp.MustCompile(`\s+`) + // trailingCommaPattern matches trailing commas before closing braces/brackets + trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) + // unquotedKeyPattern matches unquoted JSON keys that need quoting + unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`) +) + +// parseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. +// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent. +// Returns the cleaned text (with tool calls removed) and extracted tool uses. 
+func (e *KiroExecutor) parseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []kiroToolUse) { + if !strings.Contains(text, "[Called") { + return text, nil + } + + var toolUses []kiroToolUse + cleanText := text + + // Find all [Called markers + matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return text, nil + } + + // Process matches in reverse order to maintain correct indices + for i := len(matches) - 1; i >= 0; i-- { + matchStart := matches[i][0] + toolNameStart := matches[i][2] + toolNameEnd := matches[i][3] + + if toolNameStart < 0 || toolNameEnd < 0 { + continue + } + + toolName := text[toolNameStart:toolNameEnd] + + // Find the JSON object start (after "with args:") + jsonStart := matches[i][1] + if jsonStart >= len(text) { + continue + } + + // Skip whitespace to find the opening brace + for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') { + jsonStart++ + } + + if jsonStart >= len(text) || text[jsonStart] != '{' { + continue + } + + // Find matching closing bracket + jsonEnd := findMatchingBracket(text, jsonStart) + if jsonEnd < 0 { + continue + } + + // Extract JSON and find the closing bracket of [Called ...] 
+ jsonStr := text[jsonStart : jsonEnd+1] + + // Find the closing ] after the JSON + closingBracket := jsonEnd + 1 + for closingBracket < len(text) && text[closingBracket] != ']' { + closingBracket++ + } + if closingBracket >= len(text) { + continue + } + + // Extract and repair the full tool call text + fullMatch := text[matchStart : closingBracket+1] + + // Repair and parse JSON + repairedJSON := repairJSON(jsonStr) + var inputMap map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil { + log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr) + continue + } + + // Generate unique tool ID + toolUseID := "toolu_" + uuid.New().String()[:12] + + // Check for duplicates using name+input as key + dedupeKey := toolName + ":" + repairedJSON + if processedIDs != nil { + if processedIDs[dedupeKey] { + log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName) + // Still remove from text even if duplicate + cleanText = strings.Replace(cleanText, fullMatch, "", 1) + continue + } + processedIDs[dedupeKey] = true + } + + toolUses = append(toolUses, kiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + + log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID) + + // Remove from clean text + cleanText = strings.Replace(cleanText, fullMatch, "", 1) + } + + // Clean up extra whitespace + cleanText = strings.TrimSpace(cleanText) + cleanText = whitespaceCollapsePattern.ReplaceAllString(cleanText, " ") + + return cleanText, toolUses +} + +// findMatchingBracket finds the index of the closing brace/bracket that matches +// the opening one at startPos. Handles nested objects and strings correctly. 
+func findMatchingBracket(text string, startPos int) int { + if startPos >= len(text) { + return -1 + } + + openChar := text[startPos] + var closeChar byte + switch openChar { + case '{': + closeChar = '}' + case '[': + closeChar = ']' + default: + return -1 + } + + depth := 1 + inString := false + escapeNext := false + + for i := startPos + 1; i < len(text); i++ { + char := text[i] + + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' && inString { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if !inString { + if char == openChar { + depth++ + } else if char == closeChar { + depth-- + if depth == 0 { + return i + } + } + } + } + + return -1 +} + +// repairJSON attempts to fix common JSON issues that may occur in tool call arguments. +// Based on AIClient-2-API's JSON repair implementation. +// Uses pre-compiled regex patterns for performance. +func repairJSON(raw string) string { + // Remove trailing commas before closing braces/brackets + repaired := trailingCommaPattern.ReplaceAllString(raw, "$1") + // Fix unquoted keys (basic attempt - handles simple cases) + repaired = unquotedKeyPattern.ReplaceAllString(repaired, `$1"$2":`) + return repaired +} + +// processToolUseEvent handles a toolUseEvent from the Kiro stream. +// It accumulates input fragments and emits tool_use blocks when complete. +// Returns events to emit and updated state. 
+func (e *KiroExecutor) processToolUseEvent(event map[string]interface{}, currentToolUse *toolUseState, processedIDs map[string]bool) ([]kiroToolUse, *toolUseState) { + var toolUses []kiroToolUse + + // Extract from nested toolUseEvent or direct format + tu := event + if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok { + tu = nested + } + + toolUseID := getString(tu, "toolUseId") + toolName := getString(tu, "name") + isStop := false + if stop, ok := tu["stop"].(bool); ok { + isStop = stop + } + + // Get input - can be string (fragment) or object (complete) + var inputFragment string + var inputMap map[string]interface{} + + if inputRaw, ok := tu["input"]; ok { + switch v := inputRaw.(type) { + case string: + inputFragment = v + case map[string]interface{}: + inputMap = v + } + } + + // New tool use starting + if toolUseID != "" && toolName != "" { + if currentToolUse != nil && currentToolUse.toolUseID != toolUseID { + // New tool use arrived while another is in progress (interleaved events) + // This is unusual - log warning and complete the previous one + log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous", + toolUseID, currentToolUse.toolUseID) + // Emit incomplete previous tool use + if !processedIDs[currentToolUse.toolUseID] { + incomplete := kiroToolUse{ + ToolUseID: currentToolUse.toolUseID, + Name: currentToolUse.name, + } + if currentToolUse.inputBuffer.Len() > 0 { + var input map[string]interface{} + if err := json.Unmarshal([]byte(currentToolUse.inputBuffer.String()), &input); err == nil { + incomplete.Input = input + } + } + toolUses = append(toolUses, incomplete) + processedIDs[currentToolUse.toolUseID] = true + } + currentToolUse = nil + } + + if currentToolUse == nil { + // Check for duplicate + if processedIDs != nil && processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID) + return nil, nil + } + + currentToolUse = &toolUseState{ + 
toolUseID: toolUseID, + name: toolName, + } + log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID) + } + } + + // Accumulate input fragments + if currentToolUse != nil && inputFragment != "" { + currentToolUse.inputBuffer.WriteString(inputFragment) + log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.inputBuffer.Len()) + } + + // If complete input object provided directly + if currentToolUse != nil && inputMap != nil { + inputBytes, _ := json.Marshal(inputMap) + currentToolUse.inputBuffer.Reset() + currentToolUse.inputBuffer.Write(inputBytes) + } + + // Tool use complete + if isStop && currentToolUse != nil { + fullInput := currentToolUse.inputBuffer.String() + + // Repair and parse the accumulated JSON + repairedJSON := repairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput) + // Use empty input as fallback + finalInput = make(map[string]interface{}) + } + + toolUse := kiroToolUse{ + ToolUseID: currentToolUse.toolUseID, + Name: currentToolUse.name, + Input: finalInput, + } + toolUses = append(toolUses, toolUse) + + // Mark as processed + if processedIDs != nil { + processedIDs[currentToolUse.toolUseID] = true + } + + log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) + return toolUses, nil // Reset state + } + + return toolUses, currentToolUse +} + +// deduplicateToolUses removes duplicate tool uses based on toolUseId. 
+func deduplicateToolUses(toolUses []kiroToolUse) []kiroToolUse { + seen := make(map[string]bool) + var unique []kiroToolUse + + for _, tu := range toolUses { + if !seen[tu.ToolUseID] { + seen[tu.ToolUseID] = true + unique = append(unique, tu) + } else { + log.Debugf("kiro: removing duplicate tool use: %s", tu.ToolUseID) + } + } + + return unique +} diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index ab0f626a..8998eb23 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -14,11 +15,19 @@ import ( "golang.org/x/net/proxy" ) +// httpClientCache caches HTTP clients by proxy URL to enable connection reuse +var ( + httpClientCache = make(map[string]*http.Client) + httpClientCacheMutex sync.RWMutex +) + // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // 1. Use auth.ProxyURL if configured (highest priority) // 2. Use cfg.ProxyURL if auth proxy is not configured // 3. Use RoundTripper from context if neither are configured // +// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse. 
+// // Parameters: // - ctx: The context containing optional RoundTripper // - cfg: The application configuration @@ -28,11 +37,6 @@ import ( // Returns: // - *http.Client: An HTTP client with configured proxy or transport func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { - httpClient := &http.Client{} - if timeout > 0 { - httpClient.Timeout = timeout - } - // Priority 1: Use auth.ProxyURL if configured var proxyURL string if auth != nil { @@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip proxyURL = strings.TrimSpace(cfg.ProxyURL) } + // Build cache key from proxy URL (empty string for no proxy) + cacheKey := proxyURL + + // Check cache first + httpClientCacheMutex.RLock() + if cachedClient, ok := httpClientCache[cacheKey]; ok { + httpClientCacheMutex.RUnlock() + // Return a wrapper with the requested timeout but shared transport + if timeout > 0 { + return &http.Client{ + Transport: cachedClient.Transport, + Timeout: timeout, + } + } + return cachedClient + } + httpClientCacheMutex.RUnlock() + + // Create new client + httpClient := &http.Client{} + if timeout > 0 { + httpClient.Timeout = timeout + } + // If we have a proxy URL configured, set up the transport if proxyURL != "" { transport := buildProxyTransport(proxyURL) if transport != nil { httpClient.Transport = transport + // Cache the client + httpClientCacheMutex.Lock() + httpClientCache[cacheKey] = httpClient + httpClientCacheMutex.Unlock() return httpClient } // If proxy setup failed, log and fall through to context RoundTripper @@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip httpClient.Transport = rt } + // Cache the client for no-proxy case + if proxyURL == "" { + httpClientCacheMutex.Lock() + httpClientCache[cacheKey] = httpClient + httpClientCacheMutex.Unlock() + } + return httpClient } diff --git 
a/internal/translator/claude/gemini/claude_gemini_response.go b/internal/translator/claude/gemini/claude_gemini_response.go index 0c90398e..72e1820c 100644 --- a/internal/translator/claude/gemini/claude_gemini_response.go +++ b/internal/translator/claude/gemini/claude_gemini_response.go @@ -331,8 +331,9 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, streamingEvents := make([][]byte, 0) scanner := bufio.NewScanner(bytes.NewReader(rawJSON)) - buffer := make([]byte, 20_971_520) - scanner.Buffer(buffer, 20_971_520) + // Use a smaller initial buffer (64KB) that can grow up to 20MB if needed + // This prevents allocating 20MB for every request regardless of size + scanner.Buffer(make([]byte, 64*1024), 20_971_520) for scanner.Scan() { line := scanner.Bytes() // log.Debug(string(line)) diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go index f8fd4018..5b0238bf 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go @@ -50,6 +50,10 @@ type ToolCallAccumulator struct { // Returns: // - []string: A slice of strings, each containing an OpenAI-compatible JSON response func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { + var localParam any + if param == nil { + param = &localParam + } if *param == nil { *param = &ConvertAnthropicResponseToOpenAIParams{ CreatedAt: 0, diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 733668f3..e7f6275f 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -60,12 +60,12 @@ 
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Track whether tools are being used in this response chunk usedTool := false - output := "" + var sb strings.Builder // Initialize the streaming session with a message_start event // This is only sent for the very first response chunk to establish the streaming session if !(*param).(*Params).HasFirstResponse { - output = "event: message_start\n" + sb.WriteString("event: message_start\n") // Create the initial message structure with default values according to Claude Code API specification // This follows the Claude Code API specification for streaming message initialization @@ -78,7 +78,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String()) } - output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)) (*param).(*Params).HasFirstResponse = true } @@ -101,62 +101,52 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if partResult.Get("thought").Bool() { // Continue existing thinking block if already in thinking state if (*param).(*Params).ResponseType == 2 { - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) } else { // Transition from another state to thinking // First, close any existing content block if (*param).(*Params).ResponseType != 0 { - if (*param).(*Params).ResponseType == 2 { - // 
output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ } // Start a new thinking content block - output = output + "event: content_block_start\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_start\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) (*param).(*Params).ResponseType = 2 // Set state to thinking } } else { // Process regular text content (user-visible output) // Continue existing text block if already in content state if (*param).(*Params).ResponseType == 1 { - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_delta\n") data, _ := 
sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) } else { // Transition from another state to text content // First, close any existing content block if (*param).(*Params).ResponseType != 0 { - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ } // Start a new text content block - output = output + "event: content_block_start\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_start\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") + sb.WriteString("event: content_block_delta\n") data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String()) - output = output + fmt.Sprintf("data: 
%s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) (*param).(*Params).ResponseType = 1 // Set state to content } } @@ -169,42 +159,35 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Handle state transitions when switching to function calls // Close any existing function call block first if (*param).(*Params).ResponseType == 3 { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseType = 0 } - // Special handling for thinking state transition - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - // Close any other existing content block if (*param).(*Params).ResponseType != 0 { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") (*param).(*Params).ResponseIndex++ } // Start a new tool use content block // This creates the structure for a function call in Claude Code format - output = output + "event: content_block_start\n" + sb.WriteString("event: content_block_start\n") // Create the tool use block with unique ID and function details data := 
fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) data, _ = sjson.Set(data, "content_block.name", fcName) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - output = output + "event: content_block_delta\n" + sb.WriteString("event: content_block_delta\n") data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw) - output = output + fmt.Sprintf("data: %s\n\n\n", data) + sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) } (*param).(*Params).ResponseType = 3 } @@ -216,13 +199,13 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { // Close the final content block - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) - output = output + "\n\n\n" + sb.WriteString("event: content_block_stop\n") + sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) + sb.WriteString("\n\n\n") // Send the final message delta with usage information and stop reason - output = output + "event: message_delta\n" - output = output + `data: ` + sb.WriteString("event: message_delta\n") + sb.WriteString(`data: `) // Create the message delta template with appropriate stop reason template := 
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` @@ -236,11 +219,11 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) - output = output + template + "\n\n\n" + sb.WriteString(template + "\n\n\n") } } - return []string{output} + return []string{sb.String()} } // ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. diff --git a/internal/translator/init.go b/internal/translator/init.go index 084ea7ac..d19d9b34 100644 --- a/internal/translator/init.go +++ b/internal/translator/init.go @@ -33,4 +33,7 @@ import ( _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses" + + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions" ) diff --git a/internal/translator/kiro/claude/init.go b/internal/translator/kiro/claude/init.go new file mode 100644 index 00000000..9e3a2ba3 --- /dev/null +++ b/internal/translator/kiro/claude/init.go @@ -0,0 +1,19 @@ +package claude + +import ( + . 
"github.com/router-for-me/CLIProxyAPI/v6/internal/constant" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" +) + +func init() { + translator.Register( + Claude, + Kiro, + ConvertClaudeRequestToKiro, + interfaces.TranslateResponse{ + Stream: ConvertKiroResponseToClaude, + NonStream: ConvertKiroResponseToClaudeNonStream, + }, + ) +} diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go new file mode 100644 index 00000000..9922860e --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude.go @@ -0,0 +1,24 @@ +// Package claude provides translation between Kiro and Claude formats. +// Since Kiro uses Claude-compatible format internally, translations are mostly pass-through. +package claude + +import ( + "bytes" + "context" +) + +// ConvertClaudeRequestToKiro converts Claude request to Kiro format. +// Since Kiro uses Claude format internally, this is mostly a pass-through. +func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + return bytes.Clone(inputRawJSON) +} + +// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format. +func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + return []string{string(rawResponse)} +} + +// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format. 
+func ConvertKiroResponseToClaudeNonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + return string(rawResponse) +} diff --git a/internal/translator/kiro/openai/chat-completions/init.go b/internal/translator/kiro/openai/chat-completions/init.go new file mode 100644 index 00000000..2a99d0e0 --- /dev/null +++ b/internal/translator/kiro/openai/chat-completions/init.go @@ -0,0 +1,19 @@ +package chat_completions + +import ( + . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" +) + +func init() { + translator.Register( + OpenAI, + Kiro, + ConvertOpenAIRequestToKiro, + interfaces.TranslateResponse{ + Stream: ConvertKiroResponseToOpenAI, + NonStream: ConvertKiroResponseToOpenAINonStream, + }, + ) +} diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go new file mode 100644 index 00000000..fc850d96 --- /dev/null +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -0,0 +1,258 @@ +// Package chat_completions provides request translation from OpenAI to Kiro format. +package chat_completions + +import ( + "bytes" + "encoding/json" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format. +// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format. +// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result. 
+func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + rawJSON := bytes.Clone(inputRawJSON) + root := gjson.ParseBytes(rawJSON) + + // Build Claude-compatible request + out := `{"model":"","max_tokens":32000,"messages":[]}` + + // Set model + out, _ = sjson.Set(out, "model", modelName) + + // Copy max_tokens if present + if v := root.Get("max_tokens"); v.Exists() { + out, _ = sjson.Set(out, "max_tokens", v.Int()) + } + + // Copy temperature if present + if v := root.Get("temperature"); v.Exists() { + out, _ = sjson.Set(out, "temperature", v.Float()) + } + + // Copy top_p if present + if v := root.Get("top_p"); v.Exists() { + out, _ = sjson.Set(out, "top_p", v.Float()) + } + + // Convert OpenAI tools to Claude tools format + if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { + claudeTools := make([]interface{}, 0) + for _, tool := range tools.Array() { + if tool.Get("type").String() == "function" { + fn := tool.Get("function") + claudeTool := map[string]interface{}{ + "name": fn.Get("name").String(), + "description": fn.Get("description").String(), + } + // Convert parameters to input_schema + if params := fn.Get("parameters"); params.Exists() { + claudeTool["input_schema"] = params.Value() + } else { + claudeTool["input_schema"] = map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + } + claudeTools = append(claudeTools, claudeTool) + } + } + if len(claudeTools) > 0 { + out, _ = sjson.Set(out, "tools", claudeTools) + } + } + + // Process messages + messages := root.Get("messages") + if messages.Exists() && messages.IsArray() { + claudeMessages := make([]interface{}, 0) + var systemPrompt string + + // Track pending tool results to merge with next user message + var pendingToolResults []map[string]interface{} + + for _, msg := range messages.Array() { + role := msg.Get("role").String() + content := msg.Get("content") + + if role == "system" { + // Extract system message + 
if content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + systemPrompt += part.Get("text").String() + "\n" + } + } + } else { + systemPrompt = content.String() + } + continue + } + + if role == "tool" { + // OpenAI tool message -> Claude tool_result content block + toolCallID := msg.Get("tool_call_id").String() + toolContent := content.String() + + toolResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": toolCallID, + } + + // Handle content - can be string or structured + if content.IsArray() { + contentParts := make([]interface{}, 0) + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": part.Get("text").String(), + }) + } + } + toolResult["content"] = contentParts + } else { + toolResult["content"] = toolContent + } + + pendingToolResults = append(pendingToolResults, toolResult) + continue + } + + claudeMsg := map[string]interface{}{ + "role": role, + } + + // Handle assistant messages with tool_calls + if role == "assistant" && msg.Get("tool_calls").Exists() { + contentParts := make([]interface{}, 0) + + // Add text content if present + if content.Exists() && content.String() != "" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": content.String(), + }) + } + + // Convert tool_calls to tool_use blocks + for _, toolCall := range msg.Get("tool_calls").Array() { + toolUseID := toolCall.Get("id").String() + fnName := toolCall.Get("function.name").String() + fnArgs := toolCall.Get("function.arguments").String() + + // Parse arguments JSON + var argsMap map[string]interface{} + if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil { + argsMap = map[string]interface{}{"raw": fnArgs} + } + + contentParts = append(contentParts, map[string]interface{}{ + "type": "tool_use", + "id": toolUseID, + "name": fnName, + "input": argsMap, + 
}) + } + + claudeMsg["content"] = contentParts + claudeMessages = append(claudeMessages, claudeMsg) + continue + } + + // Handle user messages - may need to include pending tool results + if role == "user" && len(pendingToolResults) > 0 { + contentParts := make([]interface{}, 0) + + // Add pending tool results first + for _, tr := range pendingToolResults { + contentParts = append(contentParts, tr) + } + pendingToolResults = nil + + // Add user content + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + if partType == "text" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": part.Get("text").String(), + }) + } else if partType == "image_url" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": part.Get("image_url.url").String(), + }, + }) + } + } + } else if content.String() != "" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": content.String(), + }) + } + + claudeMsg["content"] = contentParts + claudeMessages = append(claudeMessages, claudeMsg) + continue + } + + // Handle regular content + if content.IsArray() { + contentParts := make([]interface{}, 0) + for _, part := range content.Array() { + partType := part.Get("type").String() + if partType == "text" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "text", + "text": part.Get("text").String(), + }) + } else if partType == "image_url" { + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": part.Get("image_url.url").String(), + }, + }) + } + } + claudeMsg["content"] = contentParts + } else { + claudeMsg["content"] = content.String() + } + + claudeMessages = append(claudeMessages, claudeMsg) + } + + // If there are pending tool results without a following user message, + 
// create a user message with just the tool results + if len(pendingToolResults) > 0 { + contentParts := make([]interface{}, 0) + for _, tr := range pendingToolResults { + contentParts = append(contentParts, tr) + } + claudeMessages = append(claudeMessages, map[string]interface{}{ + "role": "user", + "content": contentParts, + }) + } + + out, _ = sjson.Set(out, "messages", claudeMessages) + + if systemPrompt != "" { + out, _ = sjson.Set(out, "system", systemPrompt) + } + } + + // Set stream + out, _ = sjson.Set(out, "stream", stream) + + return []byte(out) +} diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go new file mode 100644 index 00000000..6a0ad250 --- /dev/null +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go @@ -0,0 +1,316 @@ +// Package chat_completions provides response translation from Kiro to OpenAI format. +package chat_completions + +import ( + "context" + "encoding/json" + "time" + + "github.com/google/uuid" + "github.com/tidwall/gjson" +) + +// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format. +// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta, +// content_block_stop, message_delta, and message_stop. 
+func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + root := gjson.ParseBytes(rawResponse) + var results []string + + eventType := root.Get("type").String() + + switch eventType { + case "message_start": + // Initial message event - could emit initial chunk if needed + return results + + case "content_block_start": + // Start of a content block (text or tool_use) + blockType := root.Get("content_block.type").String() + index := int(root.Get("index").Int()) + + if blockType == "tool_use" { + // Start of tool_use block + toolUseID := root.Get("content_block.id").String() + toolName := root.Get("content_block.name").String() + + toolCall := map[string]interface{}{ + "index": index, + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": "", + }, + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + return results + + case "content_block_delta": + index := int(root.Get("index").Int()) + deltaType := root.Get("delta.type").String() + + if deltaType == "text_delta" { + // Text content delta + contentDelta := root.Get("delta.text").String() + if contentDelta != "" { + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "content": contentDelta, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + 
results = append(results, string(result)) + } + } else if deltaType == "input_json_delta" { + // Tool input delta (streaming arguments) + partialJSON := root.Get("delta.partial_json").String() + if partialJSON != "" { + toolCall := map[string]interface{}{ + "index": index, + "function": map[string]interface{}{ + "arguments": partialJSON, + }, + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + } + return results + + case "content_block_stop": + // End of content block - no output needed for OpenAI format + return results + + case "message_delta": + // Final message delta with stop_reason + stopReason := root.Get("delta.stop_reason").String() + if stopReason != "" { + finishReason := "stop" + if stopReason == "tool_use" { + finishReason = "tool_calls" + } else if stopReason == "end_turn" { + finishReason = "stop" + } else if stopReason == "max_tokens" { + finishReason = "length" + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{}, + "finish_reason": finishReason, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + return results + + case "message_stop": + // End of message - could emit [DONE] marker + return results + } + + // Fallback: handle raw content for backward compatibility + var contentDelta string + if delta := root.Get("delta.text"); delta.Exists() { + contentDelta = delta.String() + } else if 
content := root.Get("content"); content.Exists() && root.Get("type").String() == "" { + contentDelta = content.String() + } + + if contentDelta != "" { + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "content": contentDelta, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + + // Handle tool_use content blocks (Claude format) - fallback + toolUses := root.Get("delta.tool_use") + if !toolUses.Exists() { + toolUses = root.Get("tool_use") + } + if toolUses.Exists() && toolUses.IsObject() { + inputJSON := toolUses.Get("input").String() + if inputJSON == "" { + if inputObj := toolUses.Get("input"); inputObj.Exists() { + inputBytes, _ := json.Marshal(inputObj.Value()) + inputJSON = string(inputBytes) + } + } + + toolCall := map[string]interface{}{ + "index": 0, + "id": toolUses.Get("id").String(), + "type": "function", + "function": map[string]interface{}{ + "name": toolUses.Get("name").String(), + "arguments": inputJSON, + }, + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } + + return results +} + +// ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format. 
+func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + root := gjson.ParseBytes(rawResponse) + + var content string + var toolCalls []map[string]interface{} + + contentArray := root.Get("content") + if contentArray.IsArray() { + for _, item := range contentArray.Array() { + itemType := item.Get("type").String() + if itemType == "text" { + content += item.Get("text").String() + } else if itemType == "tool_use" { + // Convert Claude tool_use to OpenAI tool_calls format + inputJSON := item.Get("input").String() + if inputJSON == "" { + // If input is an object, marshal it + if inputObj := item.Get("input"); inputObj.Exists() { + inputBytes, _ := json.Marshal(inputObj.Value()) + inputJSON = string(inputBytes) + } + } + toolCall := map[string]interface{}{ + "id": item.Get("id").String(), + "type": "function", + "function": map[string]interface{}{ + "name": item.Get("name").String(), + "arguments": inputJSON, + }, + } + toolCalls = append(toolCalls, toolCall) + } + } + } else { + content = root.Get("content").String() + } + + inputTokens := root.Get("usage.input_tokens").Int() + outputTokens := root.Get("usage.output_tokens").Int() + + message := map[string]interface{}{ + "role": "assistant", + "content": content, + } + + // Add tool_calls if present + if len(toolCalls) > 0 { + message["tool_calls"] = toolCalls + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "message": message, + "finish_reason": finishReason, + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": inputTokens, + "completion_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + }, + } + + result, _ := json.Marshal(response) 
+ return string(result) +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 1f4f9043..da152141 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -21,6 +21,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "gopkg.in/yaml.v3" @@ -176,6 +177,9 @@ func (w *Watcher) Start(ctx context.Context) error { } log.Debugf("watching auth directory: %s", w.authDir) + // Watch Kiro IDE token file directory for automatic token updates + w.watchKiroIDETokenFile() + // Start the event processing goroutine go w.processEvents(ctx) @@ -184,6 +188,31 @@ func (w *Watcher) Start(ctx context.Context) error { return nil } +// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher. +// This enables automatic detection of token updates from Kiro IDE. 
+func (w *Watcher) watchKiroIDETokenFile() { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err) + return + } + + // Kiro IDE stores tokens in ~/.aws/sso/cache/ + kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) { + log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir) + return + } + + if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil { + log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd) + return + } + log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir) +} + // Stop stops the file watcher func (w *Watcher) Stop() error { w.stopDispatch() @@ -744,10 +773,20 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0 authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0 - if !isConfigEvent && !isAuthJSON { + + // Check for Kiro IDE token file changes + isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0 + + if !isConfigEvent && !isAuthJSON && !isKiroIDEToken { // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. return } + + // Handle Kiro IDE token file changes + if isKiroIDEToken { + w.handleKiroIDETokenChange(event) + return + } now := time.Now() log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name) @@ -805,6 +844,51 @@ func (w *Watcher) scheduleConfigReload() { }) } +// isKiroIDETokenFile checks if the given path is the Kiro IDE token file. 
+func (w *Watcher) isKiroIDETokenFile(path string) bool { + // Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/ + // Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes) + normalized := filepath.ToSlash(path) + return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache") +} + +// handleKiroIDETokenChange processes changes to the Kiro IDE token file. +// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth. +func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) { + log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name) + + if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 { + // Token file removed - wait briefly for potential atomic replace + time.Sleep(replaceCheckDelay) + if _, statErr := os.Stat(event.Name); statErr != nil { + log.Debugf("Kiro IDE token file removed: %s", event.Name) + return + } + } + + // Try to load the updated token + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + log.Debugf("failed to load Kiro IDE token after change: %v", err) + return + } + + log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider) + + // Trigger auth state refresh to pick up the new token + w.refreshAuthState() + + // Notify callback if set + w.clientsMutex.RLock() + cfg := w.config + w.clientsMutex.RUnlock() + + if w.reloadCallback != nil && cfg != nil { + log.Debugf("triggering server update callback after Kiro IDE token change") + w.reloadCallback(cfg) + } +} + func (w *Watcher) reloadConfigIfChanged() { data, err := os.ReadFile(w.configPath) if err != nil { @@ -1181,6 +1265,67 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey") out = append(out, a) } + // Kiro (AWS CodeWhisperer) -> synthesize auths + var kAuth *kiroauth.KiroAuth + if len(cfg.KiroKey) > 0 { 
+ kAuth = kiroauth.NewKiroAuth(cfg) + } + for i := range cfg.KiroKey { + kk := cfg.KiroKey[i] + var accessToken, profileArn string + + // Try to load from token file first + if kk.TokenFile != "" && kAuth != nil { + tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile) + if err != nil { + log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err) + } else { + accessToken = tokenData.AccessToken + profileArn = tokenData.ProfileArn + } + } + + // Override with direct config values if provided + if kk.AccessToken != "" { + accessToken = kk.AccessToken + } + if kk.ProfileArn != "" { + profileArn = kk.ProfileArn + } + + if accessToken == "" { + log.Warnf("kiro config[%d] missing access_token, skipping", i) + continue + } + + // profileArn is optional for AWS Builder ID users + id, token := idGen.next("kiro:token", accessToken, profileArn) + attrs := map[string]string{ + "source": fmt.Sprintf("config:kiro[%s]", token), + "access_token": accessToken, + } + if profileArn != "" { + attrs["profile_arn"] = profileArn + } + if kk.Region != "" { + attrs["region"] = kk.Region + } + if kk.AgentTaskType != "" { + attrs["agent_task_type"] = kk.AgentTaskType + } + proxyURL := strings.TrimSpace(kk.ProxyURL) + a := &coreauth.Auth{ + ID: id, + Provider: "kiro", + Label: "kiro-token", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + out = append(out, a) + } for i := range cfg.OpenAICompatibility { compat := &cfg.OpenAICompatibility[i] providerName := strings.ToLower(strings.TrimSpace(compat.Name)) @@ -1287,7 +1432,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) - entries, _ := os.ReadDir(w.authDir) + log.Debugf("SnapshotCoreAuths: scanning auth directory: %s", w.authDir) + entries, readErr := os.ReadDir(w.authDir) + if readErr != nil { + log.Errorf("SnapshotCoreAuths: failed to read auth directory 
%s: %v", w.authDir, readErr) + } + log.Debugf("SnapshotCoreAuths: found %d entries in auth directory", len(entries)) for _, e := range entries { if e.IsDir() { continue @@ -1306,9 +1456,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { continue } t, _ := metadata["type"].(string) + + // Detect Kiro auth files by auth_method field (they don't have "type" field) if t == "" { + if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" { + t = "kiro" + log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name) + } + } + + if t == "" { + log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name) continue } + log.Debugf("SnapshotCoreAuths: processing auth file: %s (type=%s)", name, t) provider := strings.ToLower(t) if provider == "gemini" { provider = "gemini-cli" @@ -1317,6 +1478,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if email, _ := metadata["email"].(string); email != "" { label = email } + // For Kiro, use provider field as label if available + if provider == "kiro" { + if kiroProvider, _ := metadata["provider"].(string); kiroProvider != "" { + label = fmt.Sprintf("kiro-%s", strings.ToLower(kiroProvider)) + } + } // Use relative path under authDir as ID to stay consistent with the file-based token store id := full if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" { @@ -1342,6 +1509,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + // Set NextRefreshAfter for Kiro auth based on expires_at + if provider == "kiro" { + if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" { + if expiresAt, parseErr := time.Parse(time.RFC3339, expiresAtStr); parseErr == nil { + // Refresh 30 minutes before expiry + a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + } + } + } + applyAuthExcludedModelsMeta(a, cfg, nil, "oauth") if provider == "gemini-cli" { if virtuals := 
synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 76280b3a..3bbf6f61 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" @@ -50,7 +51,25 @@ type BaseAPIHandler struct { Cfg *config.SDKConfig // OpenAICompatProviders is a list of provider names for OpenAI compatibility. - OpenAICompatProviders []string + openAICompatProviders []string + openAICompatMutex sync.RWMutex +} + +// GetOpenAICompatProviders safely returns a copy of the provider names +func (h *BaseAPIHandler) GetOpenAICompatProviders() []string { + h.openAICompatMutex.RLock() + defer h.openAICompatMutex.RUnlock() + result := make([]string, len(h.openAICompatProviders)) + copy(result, h.openAICompatProviders) + return result +} + +// SetOpenAICompatProviders safely sets the provider names +func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) { + h.openAICompatMutex.Lock() + defer h.openAICompatMutex.Unlock() + h.openAICompatProviders = make([]string, len(providers)) + copy(h.openAICompatProviders, providers) } // NewBaseAPIHandlers creates a new API handlers instance. @@ -63,11 +82,12 @@ type BaseAPIHandler struct { // Returns: // - *BaseAPIHandler: A new API handlers instance func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler { - return &BaseAPIHandler{ - Cfg: cfg, - AuthManager: authManager, - OpenAICompatProviders: openAICompatProviders, + h := &BaseAPIHandler{ + Cfg: cfg, + AuthManager: authManager, } + h.SetOpenAICompatProviders(openAICompatProviders) + return h } // UpdateClients updates the handlers' client list and configuration. 
@@ -363,7 +383,7 @@ func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, mode } // Check if the provider is a configured openai-compatibility provider - for _, pName := range h.OpenAICompatProviders { + for _, pName := range h.GetOpenAICompatProviders() { if pName == providerPart { return providerPart, modelPart, true } diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go new file mode 100644 index 00000000..b95d103b --- /dev/null +++ b/sdk/auth/kiro.go @@ -0,0 +1,357 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// extractKiroIdentifier extracts a meaningful identifier for file naming. +// Returns account name if provided, otherwise profile ARN ID. +// All extracted values are sanitized to prevent path injection attacks. +func extractKiroIdentifier(accountName, profileArn string) string { + // Priority 1: Use account name if provided + if accountName != "" { + return kiroauth.SanitizeEmailForFilename(accountName) + } + + // Priority 2: Use profile ARN ID part (sanitized to prevent path injection) + if profileArn != "" { + parts := strings.Split(profileArn, "/") + if len(parts) >= 2 { + // Sanitize the ARN component to prevent path traversal + return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1]) + } + } + + // Fallback: timestamp + return fmt.Sprintf("%d", time.Now().UnixNano()%100000) +} + +// KiroAuthenticator implements OAuth authentication for Kiro with Google login. +type KiroAuthenticator struct{} + +// NewKiroAuthenticator constructs a Kiro authenticator. +func NewKiroAuthenticator() *KiroAuthenticator { + return &KiroAuthenticator{} +} + +// Provider returns the provider key for the authenticator. 
+func (a *KiroAuthenticator) Provider() string { + return "kiro" +} + +// RefreshLead indicates how soon before expiry a refresh should be attempted. +func (a *KiroAuthenticator) RefreshLead() *time.Duration { + d := 30 * time.Minute + return &d +} + +// Login performs OAuth login for Kiro with AWS Builder ID. +func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use AWS Builder ID device code flow + tokenData, err := oauth.LoginWithBuilderID(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-aws", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "aws-builder-id", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro authentication completed successfully! 
(Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGoogle performs OAuth login for Kiro with Google. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use Google OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGoogle(ctx) + if err != nil { + return nil, fmt.Errorf("google login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-google-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-google", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "google-oauth", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro Google authentication completed successfully! 
(Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro Google authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGitHub performs OAuth login for Kiro with GitHub. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use GitHub OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGitHub(ctx) + if err != nil { + return nil, fmt.Errorf("github login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-github-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-github", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "github-oauth", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! 
(Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro GitHub authentication completed successfully!") + } + + return record, nil +} + +// ImportFromKiroIDE imports token from Kiro IDE's token file. +func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) { + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract email from JWT if not already set (for imported tokens) + if tokenData.Email == "" { + tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + // Sanitize provider to prevent path traversal (defense-in-depth) + provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider))) + if provider == "" { + provider = "imported" // Fallback for legacy tokens without provider + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: fmt.Sprintf("kiro-%s", provider), + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "kiro-ide-import", + "email": tokenData.Email, + }, + NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + } + + // Display the email if 
extracted + if tokenData.Email != "" { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email) + } else { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider) + } + + return record, nil +} + +// Refresh refreshes an expired Kiro token using AWS SSO OIDC. +func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) { + if auth == nil || auth.Metadata == nil { + return nil, fmt.Errorf("invalid auth record") + } + + refreshToken, ok := auth.Metadata["refresh_token"].(string) + if !ok || refreshToken == "" { + return nil, fmt.Errorf("refresh token not found") + } + + clientID, _ := auth.Metadata["client_id"].(string) + clientSecret, _ := auth.Metadata["client_secret"].(string) + authMethod, _ := auth.Metadata["auth_method"].(string) + + var tokenData *kiroauth.KiroTokenData + var err error + + // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint + if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + } else { + // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) + oauth := kiroauth.NewKiroOAuth(cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("token refresh failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Clone auth to avoid mutating the input parameter + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = 
tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization + updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + + return updated, nil +} diff --git a/sdk/auth/manager.go b/sdk/auth/manager.go index c6469a7d..d630f128 100644 --- a/sdk/auth/manager.go +++ b/sdk/auth/manager.go @@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config } return record, savedPath, nil } + +// SaveAuth persists an auth record directly without going through the login flow. +func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) { + if m.store == nil { + return "", fmt.Errorf("no store configured") + } + if cfg != nil { + if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok { + dirSetter.SetBaseDir(cfg.AuthDir) + } + } + return m.store.Save(context.Background(), record) +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index e82ac684..09406de5 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -14,6 +14,7 @@ func init() { registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) + registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 8b9a6639..7e1acb83 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -379,6 +379,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg)) + case "kiro": + 
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -721,6 +723,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "iflow": models = registry.GetIFlowModels() models = applyExcludedModels(models, excluded) + case "kiro": + models = registry.GetKiroModels() + models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { From 9583f6b1c5124d1ac86f93361144eec5b81838ba Mon Sep 17 00:00:00 2001 From: Mansi Date: Fri, 5 Dec 2025 23:06:07 +0300 Subject: [PATCH 10/38] refactor: extract setKiroIncognitoMode helper to reduce code duplication - Added setKiroIncognitoMode() helper function to handle Kiro auth incognito mode setting - Replaced 3 duplicate code blocks (21 lines) with single function calls (3 lines) - Kiro auth defaults to incognito mode for multi-account support - Users can override with --incognito or --no-incognito flags This addresses the code duplication noted in PR #1 review. --- cmd/server/main.go | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 28c3e514..ef949187 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -47,6 +47,19 @@ func init() { buildinfo.BuildDate = BuildDate } +// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication. +// Kiro defaults to incognito mode for multi-account support. +// Users can explicitly override with --incognito or --no-incognito flags. +func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) { + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } +} + // main is the entry point of the application. 
// It parses command-line flags, loads configuration, and starts the appropriate // service based on the provided flags (login, codex-login, or server mode). @@ -465,36 +478,18 @@ func main() { // Users can explicitly override with --no-incognito // Note: This config mutation is safe - auth commands exit after completion // and don't share config with StartService (which is in the else branch) - if useIncognito { - cfg.IncognitoBrowser = true - } else if noIncognito { - cfg.IncognitoBrowser = false - } else { - cfg.IncognitoBrowser = true // Kiro default - } + setKiroIncognitoMode(cfg, useIncognito, noIncognito) cmd.DoKiroLogin(cfg, options) } else if kiroGoogleLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito // Note: This config mutation is safe - auth commands exit after completion - if useIncognito { - cfg.IncognitoBrowser = true - } else if noIncognito { - cfg.IncognitoBrowser = false - } else { - cfg.IncognitoBrowser = true // Kiro default - } + setKiroIncognitoMode(cfg, useIncognito, noIncognito) cmd.DoKiroGoogleLogin(cfg, options) } else if kiroAWSLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito - if useIncognito { - cfg.IncognitoBrowser = true - } else if noIncognito { - cfg.IncognitoBrowser = false - } else { - cfg.IncognitoBrowser = true // Kiro default - } + setKiroIncognitoMode(cfg, useIncognito, noIncognito) cmd.DoKiroAWSLogin(cfg, options) } else if kiroImport { cmd.DoKiroImport(cfg, options) From df83ba877f40e5528ae8348e935d54829217071b Mon Sep 17 00:00:00 2001 From: Mansi Date: Fri, 5 Dec 2025 23:16:48 +0300 Subject: [PATCH 11/38] refactor: replace custom stringContains with strings.Contains - Remove custom stringContains and findSubstring helper functions - Use standard library strings.Contains for better maintainability - No functional change, just cleaner code Addresses Gemini Code 
Assist review feedback --- internal/browser/browser.go | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/internal/browser/browser.go b/internal/browser/browser.go index 1ff0b469..3a5aeea7 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -6,6 +6,7 @@ import ( "fmt" "os/exec" "runtime" + "strings" "sync" pkgbrowser "github.com/pkg/browser" @@ -208,22 +209,7 @@ func tryDefaultBrowserMacOS(url string) *exec.Cmd { // containsBrowserID checks if the LaunchServices output contains a browser ID. func containsBrowserID(output, bundleID string) bool { - return stringContains(output, bundleID) -} - -// stringContains is a simple string contains check. -func stringContains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(substr) == 0 || - (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) -} - -func findSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false + return strings.Contains(output, bundleID) } // createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers. 
@@ -287,13 +273,13 @@ func tryDefaultBrowserWindows(url string) *exec.Cmd { var browserName string // Map ProgId to browser name - if stringContains(output, "ChromeHTML") { + if strings.Contains(output, "ChromeHTML") { browserName = "chrome" - } else if stringContains(output, "FirefoxURL") { + } else if strings.Contains(output, "FirefoxURL") { browserName = "firefox" - } else if stringContains(output, "MSEdgeHTM") { + } else if strings.Contains(output, "MSEdgeHTM") { browserName = "edge" - } else if stringContains(output, "BraveHTML") { + } else if strings.Contains(output, "BraveHTML") { browserName = "brave" } @@ -354,15 +340,15 @@ func tryDefaultBrowserLinux(url string) *exec.Cmd { var browserName string // Map .desktop file to browser name - if stringContains(desktop, "google-chrome") || stringContains(desktop, "chrome") { + if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") { browserName = "chrome" - } else if stringContains(desktop, "firefox") { + } else if strings.Contains(desktop, "firefox") { browserName = "firefox" - } else if stringContains(desktop, "chromium") { + } else if strings.Contains(desktop, "chromium") { browserName = "chromium" - } else if stringContains(desktop, "brave") { + } else if strings.Contains(desktop, "brave") { browserName = "brave" - } else if stringContains(desktop, "microsoft-edge") || stringContains(desktop, "msedge") { + } else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") { browserName = "edge" } From b06463c6d92450bedc696350fc832b870594b1b2 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 6 Dec 2025 12:04:36 +0800 Subject: [PATCH 12/38] docs: update README to include Kiro (AWS CodeWhisperer) support integration details --- README.md | 1 + README_CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 8e27ab05..b4037180 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ The Plus release stays in lockstep with the 
mainline features. ## Differences from the Mainline - Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) +- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration) ## Contributing diff --git a/README_CN.md b/README_CN.md index 84e90d40..f3260e9b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -11,6 +11,7 @@ ## 与主线版本版本差异 - 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 +- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供 ## 贡献 From b73e53d6c41eedb64d445c2fccba831c73b6533a Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 6 Dec 2025 12:14:23 +0800 Subject: [PATCH 13/38] docs: update README to fix formatting for GitHub Copilot and Kiro support sections --- README.md | 2 +- README_CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b4037180..7a552135 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ The Plus release stays in lockstep with the mainline features. 
## Differences from the Mainline -- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) +- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) - Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration) ## Contributing diff --git a/README_CN.md b/README_CN.md index f3260e9b..163bec07 100644 --- a/README_CN.md +++ b/README_CN.md @@ -10,7 +10,7 @@ ## 与主线版本版本差异 -- 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 +- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 - 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供 ## 贡献 From 68cbe2066453e24dfd65ca249f83a0a66aac98c8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 7 Dec 2025 17:56:06 +0800 Subject: [PATCH 14/38] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0Kiro=E6=B8=A0?= =?UTF-8?q?=E9=81=93=E5=9B=BE=E7=89=87=E6=94=AF=E6=8C=81=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E5=80=9F=E9=89=B4justlovemaki/AIClient-2-API=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/runtime/executor/kiro_executor.go | 43 +++++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index eb068c14..0157d68c 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -604,10 +604,22 @@ type kiroHistoryMessage struct { AssistantResponseMessage *kiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` } +// kiroImage represents an image in 
Kiro API format +type kiroImage struct { + Format string `json:"format"` + Source kiroImageSource `json:"source"` +} + +// kiroImageSource contains the image data +type kiroImageSource struct { + Bytes string `json:"bytes"` // base64 encoded image data +} + type kiroUserInputMessage struct { Content string `json:"content"` ModelID string `json:"modelId"` Origin string `json:"origin"` + Images []kiroImage `json:"images,omitempty"` UserInputMessageContext *kiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` } @@ -824,6 +836,7 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin content := msg.Get("content") var contentBuilder strings.Builder var toolResults []kiroToolResult + var images []kiroImage if content.IsArray() { for _, part := range content.Array() { @@ -831,6 +844,25 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin switch partType { case "text": contentBuilder.WriteString(part.Get("text").String()) + case "image": + // Extract image data from Claude API format + mediaType := part.Get("source.media_type").String() + data := part.Get("source.data").String() + + // Extract format from media_type (e.g., "image/png" -> "png") + format := "" + if idx := strings.LastIndex(mediaType, "/"); idx != -1 { + format = mediaType[idx+1:] + } + + if format != "" && data != "" { + images = append(images, kiroImage{ + Format: format, + Source: kiroImageSource{ + Bytes: data, + }, + }) + } case "tool_result": // Extract tool result for API toolUseID := part.Get("tool_use_id").String() @@ -872,11 +904,18 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin contentBuilder.WriteString(content.String()) } - return kiroUserInputMessage{ + userMsg := kiroUserInputMessage{ Content: contentBuilder.String(), ModelID: modelID, Origin: origin, - }, toolResults + } + + // Add images to message if present + if len(images) > 0 { + userMsg.Images = images + } + + return userMsg, 
toolResults } // buildAssistantMessageStruct builds an assistant message with tool uses From 5d716dc796b92756b895dab5e88ea8afaacb687e Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 7 Dec 2025 21:34:44 +0800 Subject: [PATCH 15/38] =?UTF-8?q?fix(kiro):=20=E4=BF=AE=E5=A4=8D=20base64?= =?UTF-8?q?=20=E5=9B=BE=E7=89=87=E6=A0=BC=E5=BC=8F=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat-completions/kiro_openai_request.go | 89 ++++++++++++++++--- 1 file changed, 75 insertions(+), 14 deletions(-) diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go index fc850d96..2dbf79b9 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -4,6 +4,7 @@ package chat_completions import ( "bytes" "encoding/json" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -182,13 +183,43 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo "text": part.Get("text").String(), }) } else if partType == "image_url" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": part.Get("image_url.url").String(), - }, - }) + imageURL := part.Get("image_url.url").String() + + // 检查是否是base64格式 (data:image/png;base64,xxxxx) + if strings.HasPrefix(imageURL, "data:") { + // 解析 data URL 格式 + // 格式: data:image/png;base64,xxxxx + commaIdx := strings.Index(imageURL, ",") + if commaIdx != -1 { + // 提取 media_type (例如 "image/png") + header := imageURL[5:commaIdx] // 去掉 "data:" 前缀 + mediaType := header + if semiIdx := strings.Index(header, ";"); semiIdx != -1 { + mediaType = header[:semiIdx] + } + + // 提取 base64 数据 + base64Data := imageURL[commaIdx+1:] + + contentParts 
= append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": base64Data, + }, + }) + } + } else { + // 普通URL格式 - 保持原有逻辑 + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + }) + } } } } else if content.String() != "" { @@ -214,13 +245,43 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo "text": part.Get("text").String(), }) } else if partType == "image_url" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": part.Get("image_url.url").String(), - }, - }) + imageURL := part.Get("image_url.url").String() + + // 检查是否是base64格式 (data:image/png;base64,xxxxx) + if strings.HasPrefix(imageURL, "data:") { + // 解析 data URL 格式 + // 格式: data:image/png;base64,xxxxx + commaIdx := strings.Index(imageURL, ",") + if commaIdx != -1 { + // 提取 media_type (例如 "image/png") + header := imageURL[5:commaIdx] // 去掉 "data:" 前缀 + mediaType := header + if semiIdx := strings.Index(header, ";"); semiIdx != -1 { + mediaType = header[:semiIdx] + } + + // 提取 base64 数据 + base64Data := imageURL[commaIdx+1:] + + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": base64Data, + }, + }) + } + } else { + // 普通URL格式 - 保持原有逻辑 + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + }) + } } } claudeMsg["content"] = contentParts From 2bf9e08b31f11b718c3e24c2eda6dbf80677f159 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 7 Dec 2025 21:50:06 +0800 Subject: [PATCH 16/38] style(kiro): convert Chinese comments to English in base64 image handling --- 
.../chat-completions/kiro_openai_request.go | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go index 2dbf79b9..3d339505 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -185,20 +185,20 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo } else if partType == "image_url" { imageURL := part.Get("image_url.url").String() - // 检查是否是base64格式 (data:image/png;base64,xxxxx) + // Check if it's base64 format (data:image/png;base64,xxxxx) if strings.HasPrefix(imageURL, "data:") { - // 解析 data URL 格式 - // 格式: data:image/png;base64,xxxxx + // Parse data URL format + // Format: data:image/png;base64,xxxxx commaIdx := strings.Index(imageURL, ",") if commaIdx != -1 { - // 提取 media_type (例如 "image/png") - header := imageURL[5:commaIdx] // 去掉 "data:" 前缀 + // Extract media_type (e.g., "image/png") + header := imageURL[5:commaIdx] // Remove "data:" prefix mediaType := header if semiIdx := strings.Index(header, ";"); semiIdx != -1 { mediaType = header[:semiIdx] } - // 提取 base64 数据 + // Extract base64 data base64Data := imageURL[commaIdx+1:] contentParts = append(contentParts, map[string]interface{}{ @@ -211,7 +211,7 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo }) } } else { - // 普通URL格式 - 保持原有逻辑 + // Regular URL format - keep original logic contentParts = append(contentParts, map[string]interface{}{ "type": "image", "source": map[string]interface{}{ @@ -247,20 +247,20 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo } else if partType == "image_url" { imageURL := part.Get("image_url.url").String() - // 检查是否是base64格式 (data:image/png;base64,xxxxx) + // Check if it's base64 format 
(data:image/png;base64,xxxxx) if strings.HasPrefix(imageURL, "data:") { - // 解析 data URL 格式 - // 格式: data:image/png;base64,xxxxx + // Parse data URL format + // Format: data:image/png;base64,xxxxx commaIdx := strings.Index(imageURL, ",") if commaIdx != -1 { - // 提取 media_type (例如 "image/png") - header := imageURL[5:commaIdx] // 去掉 "data:" 前缀 + // Extract media_type (e.g., "image/png") + header := imageURL[5:commaIdx] // Remove "data:" prefix mediaType := header if semiIdx := strings.Index(header, ";"); semiIdx != -1 { mediaType = header[:semiIdx] } - // 提取 base64 数据 + // Extract base64 data base64Data := imageURL[commaIdx+1:] contentParts = append(contentParts, map[string]interface{}{ @@ -273,7 +273,7 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo }) } } else { - // 普通URL格式 - 保持原有逻辑 + // Regular URL format - keep original logic contentParts = append(contentParts, map[string]interface{}{ "type": "image", "source": map[string]interface{}{ From a0c6cffb0da0d53d94a6c6bfb8cdf528773a09ee Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sun, 7 Dec 2025 21:55:13 +0800 Subject: [PATCH 17/38] =?UTF-8?q?fix(kiro):=E4=BF=AE=E5=A4=8D=20base64=20?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E6=A0=BC=E5=BC=8F=E8=BD=AC=E6=8D=A2=E9=97=AE?= =?UTF-8?q?=E9=A2=98=20(#10)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat-completions/kiro_openai_request.go | 89 ++++++++++++++++--- 1 file changed, 75 insertions(+), 14 deletions(-) diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go index fc850d96..3d339505 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -4,6 +4,7 @@ package chat_completions import ( "bytes" "encoding/json" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ 
-182,13 +183,43 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo "text": part.Get("text").String(), }) } else if partType == "image_url" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": part.Get("image_url.url").String(), - }, - }) + imageURL := part.Get("image_url.url").String() + + // Check if it's base64 format (data:image/png;base64,xxxxx) + if strings.HasPrefix(imageURL, "data:") { + // Parse data URL format + // Format: data:image/png;base64,xxxxx + commaIdx := strings.Index(imageURL, ",") + if commaIdx != -1 { + // Extract media_type (e.g., "image/png") + header := imageURL[5:commaIdx] // Remove "data:" prefix + mediaType := header + if semiIdx := strings.Index(header, ";"); semiIdx != -1 { + mediaType = header[:semiIdx] + } + + // Extract base64 data + base64Data := imageURL[commaIdx+1:] + + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": base64Data, + }, + }) + } + } else { + // Regular URL format - keep original logic + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + }) + } } } } else if content.String() != "" { @@ -214,13 +245,43 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo "text": part.Get("text").String(), }) } else if partType == "image_url" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": part.Get("image_url.url").String(), - }, - }) + imageURL := part.Get("image_url.url").String() + + // Check if it's base64 format (data:image/png;base64,xxxxx) + if strings.HasPrefix(imageURL, "data:") { + // Parse data URL format + // Format: data:image/png;base64,xxxxx + 
commaIdx := strings.Index(imageURL, ",") + if commaIdx != -1 { + // Extract media_type (e.g., "image/png") + header := imageURL[5:commaIdx] // Remove "data:" prefix + mediaType := header + if semiIdx := strings.Index(header, ";"); semiIdx != -1 { + mediaType = header[:semiIdx] + } + + // Extract base64 data + base64Data := imageURL[commaIdx+1:] + + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": base64Data, + }, + }) + } + } else { + // Regular URL format - keep original logic + contentParts = append(contentParts, map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + }) + } } } claudeMsg["content"] = contentParts From ab9e9442ec0a63a29ebf6f9468d330fd76280bb9 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 8 Dec 2025 22:32:29 +0800 Subject: [PATCH 18/38] v6.5.56 (#12) * feat(api): add comprehensive ampcode management endpoints Add new REST API endpoints under /v0/management/ampcode for managing ampcode configuration including upstream URL, API key, localhost restriction, model mappings, and force model mappings settings. 
- Move force-model-mappings from config_basic to config_lists - Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings - Support model mapping CRUD with upsert (PATCH) capability - Add comprehensive test coverage for all ampcode endpoints * refactor(api): simplify request body parsing in ampcode handlers * feat(logging): add upstream API request/response capture to streaming logs * style(logging): remove redundant separator line from response section * feat(antigravity): enforce thinking budget limits for Claude models * refactor(logging): remove unused variable in `ensureAttempt` and redundant function call --------- Co-authored-by: hkfires <10558748+hkfires@users.noreply.github.com> --- .../api/handlers/management/config_basic.go | 8 - .../api/handlers/management/config_lists.go | 152 ++++ internal/api/handlers/management/handler.go | 10 - internal/api/middleware/response_writer.go | 9 + internal/api/server.go | 22 +- internal/logging/request_logger.go | 202 ++++- internal/registry/model_definitions.go | 31 +- .../runtime/executor/antigravity_executor.go | 63 +- internal/runtime/executor/logging_helpers.go | 4 +- .../antigravity_openai_request.go | 5 +- test/amp_management_test.go | 827 ++++++++++++++++++ 11 files changed, 1263 insertions(+), 70 deletions(-) create mode 100644 test/amp_management_test.go diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index c788aca4..f9069198 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -241,11 +241,3 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } - -// Force Model Mappings (for Amp CLI) -func (h *Handler) GetForceModelMappings(c *gin.Context) { - c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) -} -func (h *Handler) PutForceModelMappings(c *gin.Context) { - h.updateBoolField(c, func(v bool) { 
h.cfg.AmpCode.ForceModelMappings = v }) -} diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 8f4c4037..a0d0b169 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { } entry.Models = normalized } + +// GetAmpCode returns the complete ampcode configuration. +func (h *Handler) GetAmpCode(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"ampcode": config.AmpCode{}}) + return + } + c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode}) +} + +// GetAmpUpstreamURL returns the ampcode upstream URL. +func (h *Handler) GetAmpUpstreamURL(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-url": ""}) + return + } + c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL}) +} + +// PutAmpUpstreamURL updates the ampcode upstream URL. +func (h *Handler) PutAmpUpstreamURL(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamURL clears the ampcode upstream URL. +func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) { + h.cfg.AmpCode.UpstreamURL = "" + h.persist(c) +} + +// GetAmpUpstreamAPIKey returns the ampcode upstream API key. +func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-api-key": ""}) + return + } + c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey}) +} + +// PutAmpUpstreamAPIKey updates the ampcode upstream API key. +func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key. 
+func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) { + h.cfg.AmpCode.UpstreamAPIKey = "" + h.persist(c) +} + +// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting. +func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"restrict-management-to-localhost": true}) + return + } + c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost}) +} + +// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting. +func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v }) +} + +// GetAmpModelMappings returns the ampcode model mappings. +func (h *Handler) GetAmpModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}}) + return + } + c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings}) +} + +// PutAmpModelMappings replaces all ampcode model mappings. +func (h *Handler) PutAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + h.cfg.AmpCode.ModelMappings = body.Value + h.persist(c) +} + +// PatchAmpModelMappings adds or updates model mappings. 
+func (h *Handler) PatchAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + + existing := make(map[string]int) + for i, m := range h.cfg.AmpCode.ModelMappings { + existing[strings.TrimSpace(m.From)] = i + } + + for _, newMapping := range body.Value { + from := strings.TrimSpace(newMapping.From) + if idx, ok := existing[from]; ok { + h.cfg.AmpCode.ModelMappings[idx] = newMapping + } else { + h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping) + existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1 + } + } + h.persist(c) +} + +// DeleteAmpModelMappings removes specified model mappings by "from" field. +func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 { + h.cfg.AmpCode.ModelMappings = nil + h.persist(c) + return + } + + toRemove := make(map[string]bool) + for _, from := range body.Value { + toRemove[strings.TrimSpace(from)] = true + } + + newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings)) + for _, m := range h.cfg.AmpCode.ModelMappings { + if !toRemove[strings.TrimSpace(m.From)] { + newMappings = append(newMappings, m) + } + } + h.cfg.AmpCode.ModelMappings = newMappings + h.persist(c) +} + +// GetAmpForceModelMappings returns whether model mappings are forced. +func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"force-model-mappings": false}) + return + } + c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) +} + +// PutAmpForceModelMappings updates the force model mappings setting. 
+func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) +} diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index ef6f400a..39e6b7fd 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { Value *bool `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { - var m map[string]any - if err2 := c.ShouldBindJSON(&m); err2 == nil { - for _, v := range m { - if b, ok := v.(bool); ok { - set(b) - h.persist(c) - return - } - } - } c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index f0d1ad26..b7259bc6 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { w.streamDone = nil } + // Write API Request and Response to the streaming log before closing if w.streamWriter != nil { + apiRequest := w.extractAPIRequest(c) + if len(apiRequest) > 0 { + _ = w.streamWriter.WriteAPIRequest(apiRequest) + } + apiResponse := w.extractAPIResponse(c) + if len(apiResponse) > 0 { + _ = w.streamWriter.WriteAPIResponse(apiResponse) + } if err := w.streamWriter.Close(); err != nil { w.streamWriter = nil return err diff --git a/internal/api/server.go b/internal/api/server.go index 1f35429e..2e463de4 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,9 +520,25 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) - mgmt.GET("/force-model-mappings", s.mgmt.GetForceModelMappings) - mgmt.PUT("/force-model-mappings", 
s.mgmt.PutForceModelMappings) - mgmt.PATCH("/force-model-mappings", s.mgmt.PutForceModelMappings) + mgmt.GET("/ampcode", s.mgmt.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index c574febb..f8c068c5 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -84,6 +84,26 @@ type StreamingLogWriter interface { // - error: An error if writing fails, nil otherwise WriteStatus(status int, headers map[string][]string) error + // WriteAPIRequest writes the upstream 
API request details to the log. + // This should be called before WriteStatus to maintain proper log ordering. + // + // Parameters: + // - apiRequest: The API request data (typically includes URL, headers, body sent upstream) + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIRequest(apiRequest []byte) error + + // WriteAPIResponse writes the upstream API response details to the log. + // This should be called after the streaming response is complete. + // + // Parameters: + // - apiResponse: The API response data + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIResponse(apiResponse []byte) error + // Close finalizes the log file and cleans up resources. // // Returns: @@ -248,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ // Create streaming writer writer := &FileStreamingLogWriter{ - file: file, - chunkChan: make(chan []byte, 100), // Buffered channel for async writes - closeChan: make(chan struct{}), - errorChan: make(chan error, 1), + file: file, + chunkChan: make(chan []byte, 100), // Buffered channel for async writes + closeChan: make(chan struct{}), + errorChan: make(chan error, 1), + bufferedChunks: &bytes.Buffer{}, } // Start async writer goroutine @@ -628,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. // It handles asynchronous writing of streaming response chunks to a file. +// All data is buffered and written in the correct order when Close is called. type FileStreamingLogWriter struct { // file is the file where log data is written. file *os.File - // chunkChan is a channel for receiving response chunks to write. + // chunkChan is a channel for receiving response chunks to buffer. chunkChan chan []byte // closeChan is a channel for signaling when the writer is closed. 
@@ -641,8 +663,23 @@ type FileStreamingLogWriter struct { // errorChan is a channel for reporting errors during writing. errorChan chan error - // statusWritten indicates whether the response status has been written. + // bufferedChunks stores the response chunks in order. + bufferedChunks *bytes.Buffer + + // responseStatus stores the HTTP status code. + responseStatus int + + // statusWritten indicates whether a non-zero status was recorded. statusWritten bool + + // responseHeaders stores the response headers. + responseHeaders map[string][]string + + // apiRequest stores the upstream API request data. + apiRequest []byte + + // apiResponse stores the upstream API response data. + apiResponse []byte } // WriteChunkAsync writes a response chunk asynchronously (non-blocking). @@ -666,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) { } } -// WriteStatus writes the response status and headers to the log. +// WriteStatus buffers the response status and headers for later writing. 
// // Parameters: // - status: The response status code // - headers: The response headers // // Returns: -// - error: An error if writing fails, nil otherwise +// - error: Always returns nil (buffering cannot fail) func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { - if w.file == nil || w.statusWritten { + if status == 0 { return nil } - var content strings.Builder - content.WriteString("========================================\n") - content.WriteString("=== RESPONSE ===\n") - content.WriteString(fmt.Sprintf("Status: %d\n", status)) - - for key, values := range headers { - for _, value := range values { - content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + w.responseStatus = status + if headers != nil { + w.responseHeaders = make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + w.responseHeaders[key] = headerValues } } - content.WriteString("\n") + w.statusWritten = true + return nil +} - _, err := w.file.WriteString(content.String()) - if err == nil { - w.statusWritten = true +// WriteAPIRequest buffers the upstream API request details for later writing. +// +// Parameters: +// - apiRequest: The API request data (typically includes URL, headers, body sent upstream) +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error { + if len(apiRequest) == 0 { + return nil } - return err + w.apiRequest = bytes.Clone(apiRequest) + return nil +} + +// WriteAPIResponse buffers the upstream API response details for later writing. 
+// +// Parameters: +// - apiResponse: The API response data +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { + if len(apiResponse) == 0 { + return nil + } + w.apiResponse = bytes.Clone(apiResponse) + return nil } // Close finalizes the log file and cleans up resources. +// It writes all buffered data to the file in the correct order: +// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // // Returns: // - error: An error if closing fails, nil otherwise @@ -707,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error { close(w.chunkChan) } - // Wait for async writer to finish + // Wait for async writer to finish buffering chunks if w.closeChan != nil { <-w.closeChan w.chunkChan = nil } - if w.file != nil { - return w.file.Close() + if w.file == nil { + return nil } - return nil + // Write all content in the correct order + var content strings.Builder + + // 1. Write API REQUEST section + if len(w.apiRequest) > 0 { + if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) { + content.Write(w.apiRequest) + if !bytes.HasSuffix(w.apiRequest, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API REQUEST ===\n") + content.Write(w.apiRequest) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 2. Write API RESPONSE section + if len(w.apiResponse) > 0 { + if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { + content.Write(w.apiResponse) + if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API RESPONSE ===\n") + content.Write(w.apiResponse) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 3. 
Write RESPONSE section (status, headers, buffered chunks) + content.WriteString("=== RESPONSE ===\n") + if w.statusWritten { + content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) + } + + for key, values := range w.responseHeaders { + for _, value := range values { + content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + } + } + content.WriteString("\n") + + // Write buffered response body chunks + if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { + content.Write(w.bufferedChunks.Bytes()) + } + + // Write the complete content to file + if _, err := w.file.WriteString(content.String()); err != nil { + _ = w.file.Close() + return err + } + + return w.file.Close() } -// asyncWriter runs in a goroutine to handle async chunk writing. -// It continuously reads chunks from the channel and writes them to the file. +// asyncWriter runs in a goroutine to buffer chunks from the channel. +// It continuously reads chunks from the channel and buffers them for later writing. func (w *FileStreamingLogWriter) asyncWriter() { defer close(w.closeChan) for chunk := range w.chunkChan { - if w.file != nil { - _, _ = w.file.Write(chunk) + if w.bufferedChunks != nil { + w.bufferedChunks.Write(chunk) } } } @@ -754,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error return nil } +// WriteAPIRequest is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiRequest: The API request data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error { + return nil +} + +// WriteAPIResponse is a no-op implementation that does nothing and always returns nil. 
+// +// Parameters: +// - apiResponse: The API response data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { + return nil +} + // Close is a no-op implementation that does nothing and always returns nil. // // Returns: diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index b25d91c2..fc7e75a1 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -943,18 +943,6 @@ func GetQwenModels() []*ModelInfo { } } -// GetAntigravityThinkingConfig returns the Thinking configuration for antigravity models. -// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. -func GetAntigravityThinkingConfig() map[string]*ThinkingSupport { - return map[string]*ThinkingSupport{ - "gemini-2.5-flash": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-2.5-flash-lite": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-3-pro-preview": {Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-sonnet-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-opus-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - } -} - // GetIFlowModels returns supported models for iFlow OAuth accounts. func GetIFlowModels() []*ModelInfo { entries := []struct { @@ -998,6 +986,25 @@ func GetIFlowModels() []*ModelInfo { return models } +// AntigravityModelConfig captures static antigravity model overrides, including +// Thinking budget limits and provider max completion tokens. +type AntigravityModelConfig struct { + Thinking *ThinkingSupport + MaxCompletionTokens int +} + +// GetAntigravityModelConfig returns static configuration for antigravity models. +// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. 
+func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { + return map[string]*AntigravityModelConfig{ + "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}}, + "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + } +} + // GetGitHubCopilotModels returns the available models for GitHub Copilot. // These models are available through the GitHub Copilot API at api.githubcopilot.com. func GetGitHubCopilotModels() []*ModelInfo { diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index ce836a77..d83559ab 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -81,6 +81,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -174,6 +175,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + 
translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -370,7 +372,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } now := time.Now().Unix() - thinkingConfig := registry.GetAntigravityThinkingConfig() + modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) for originalName := range result.Map() { aliasName := modelName2Alias(originalName) @@ -387,8 +389,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c Type: antigravityAuthType, } // Look up Thinking support from static config using alias name - if thinking, ok := thinkingConfig[aliasName]; ok { - modelInfo.Thinking = thinking + if cfg, ok := modelConfig[aliasName]; ok { + if cfg.Thinking != nil { + modelInfo.Thinking = cfg.Thinking + } + if cfg.MaxCompletionTokens > 0 { + modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens + } } models = append(models, modelInfo) } @@ -812,3 +819,53 @@ func alias2ModelName(modelName string) string { return modelName } } + +// normalizeAntigravityThinking clamps or removes thinking config based on model support. +// For Claude models, it additionally ensures thinking budget < max_tokens. 
+func normalizeAntigravityThinking(model string, payload []byte) []byte { + payload = util.StripThinkingConfigIfUnsupported(model, payload) + if !util.ModelSupportsThinking(model) { + return payload + } + budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget") + if !budget.Exists() { + return payload + } + raw := int(budget.Int()) + normalized := util.NormalizeThinkingBudget(model, raw) + + isClaude := strings.Contains(strings.ToLower(model), "claude") + if isClaude { + effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) + if effectiveMax > 0 && normalized >= effectiveMax { + normalized = effectiveMax - 1 + if normalized < 1 { + normalized = 1 + } + } + if setDefaultMax { + if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil { + payload = res + } + } + } + + updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + if err != nil { + return payload + } + return updated +} + +// antigravityEffectiveMaxTokens returns the max tokens to cap thinking: +// prefer request-provided maxOutputTokens; otherwise fall back to model default. +// The boolean indicates whether the value came from the model default (and thus should be written back). 
+func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) { + if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 { + return int(maxTok.Int()), false + } + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 { + return modelInfo.MaxCompletionTokens, true + } + return 0, false +} diff --git a/internal/runtime/executor/logging_helpers.go b/internal/runtime/executor/logging_helpers.go index 26931f53..7798b96b 100644 --- a/internal/runtime/executor/logging_helpers.go +++ b/internal/runtime/executor/logging_helpers.go @@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt if ginCtx == nil { return } - attempts, attempt := ensureAttempt(ginCtx) + _, attempt := ensureAttempt(ginCtx) ensureResponseIntro(attempt) if !attempt.headersWritten { @@ -175,8 +175,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt } attempt.response.WriteString(string(data)) attempt.bodyHasContent = true - - updateAggregatedResponse(ginCtx, attempts) } func ginContextFrom(ctx context.Context) *gin.Context { diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 82e71758..1c90a803 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -111,7 +111,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } - // Temperature/top_p/top_k + // Temperature/top_p/top_k/max_tokens if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { 
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) } @@ -121,6 +121,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) } + if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num) + } // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"] diff --git a/test/amp_management_test.go b/test/amp_management_test.go new file mode 100644 index 00000000..19450dbf --- /dev/null +++ b/test/amp_management_test.go @@ -0,0 +1,827 @@ +package test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// newAmpTestHandler creates a test handler with default ampcode configuration. 
+func newAmpTestHandler(t *testing.T) (*management.Handler, string) { + t.Helper() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamURL: "https://example.com", + UpstreamAPIKey: "test-api-key-12345", + RestrictManagementToLocalhost: true, + ForceModelMappings: false, + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "gemini-pro"}, + }, + }, + } + + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + return h, configPath +} + +// setupAmpRouter creates a test router with all ampcode management endpoints. +func setupAmpRouter(h *management.Handler) *gin.Engine { + r := gin.New() + mgmt := r.Group("/v0/management") + { + mgmt.GET("/ampcode", h.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings) + } + return r +} + +// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config. 
+func TestGetAmpCode(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]config.AmpCode + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + ampcode := resp["ampcode"] + if ampcode.UpstreamURL != "https://example.com" { + t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL) + } + if len(ampcode.ModelMappings) != 1 { + t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings)) + } +} + +// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL. +func TestGetAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["upstream-url"] != "https://example.com" { + t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL. 
+func TestPutAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-upstream.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL. +func TestDeleteAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key. +func TestGetAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + key := resp["upstream-api-key"].(string) + if key != "test-api-key-12345" { + t.Errorf("expected key %q, got %q", "test-api-key-12345", key) + } +} + +// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key. 
+func TestPutAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-key"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. +func TestDeleteAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting. +func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["restrict-management-to-localhost"] != true { + t.Error("expected restrict-management-to-localhost to be true") + } +} + +// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting. 
+func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings. +func TestGetAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping, got %d", len(mappings)) + } + if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" { + t.Errorf("unexpected mapping: %+v", mappings[0]) + } +} + +// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings. 
+func TestPutAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones. +func TestPatchAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field. +func TestDeleteAmpModelMappings_Specific(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": ["gpt-4"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. 
+func TestDeleteAmpModelMappings_All(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting. +func TestGetAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["force-model-mappings"] != false { + t.Error("expected force-model-mappings to be false") + } +} + +// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting. +func TestPutAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted. 
+func TestPutAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String()) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings, got %d", len(mappings)) + } + + expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"} + for _, m := range mappings { + if expected[m.From] != m.To { + t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To) + } + } +} + +// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly. 
+func TestPatchAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PATCH failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 2 { + t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings)) + } + + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + if found["gpt-4"] != "updated-target" { + t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"]) + } + if found["new-model"] != "new-target" { + t.Errorf("new-model should map to new-target, got %q", found["new-model"]) + } +} + +// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others. 
+func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["a", "c"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping remaining, got %d", len(mappings)) + } + if mappings[0].From != "b" || mappings[0].To != "2" { + t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To) + } +} + +// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones. 
+func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + delBody := `{"value": ["non-existent-model"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 1 { + t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"])) + } +} + +// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings. +func TestPutAmpModelMappings_Empty(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": []}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} + +// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state. 
+func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-api.example.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "https://new-api.example.com" { + t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"]) + } +} + +// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL. +func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state. 
+func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-api-key-xyz"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "new-secret-api-key-xyz" { + t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"]) + } +} + +// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key. +func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-api-key"]) + } +} + +// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction. 
+func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["restrict-management-to-localhost"] != false { + t.Error("expected false after update") + } +} + +// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting. +func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["force-model-mappings"] != true { + t.Error("expected true after update") + } +} + +// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400. 
+func TestPutBoolField_EmptyObject(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code) + } +} + +// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET. +func TestComplexMappingsWorkflow(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}` + req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["m1", "m3"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + 
t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings)) + } + + expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"} + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + for from, to := range expected { + if found[from] != to { + t.Errorf("mapping %s: expected %q, got %q", from, to, found[from]) + } + } +} + +// TestNilHandlerGetAmpCode verifies handler works with empty config. +func TestNilHandlerGetAmpCode(t *testing.T) { + cfg := &config.Config{} + h := management.NewHandler(cfg, "", nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config. +func TestEmptyConfigGetAmpModelMappings(t *testing.T) { + cfg := &config.Config{} + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} From 1fa5514d56c8ff7a56ecfca91ffb91839e752995 Mon Sep 17 00:00:00 2001 From: Mario Date: Tue, 9 Dec 2025 20:13:16 +0800 Subject: [PATCH 19/38] fix kiro cannot refresh the token 
--- internal/watcher/watcher.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index da152141..36276de9 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1272,7 +1272,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } for i := range cfg.KiroKey { kk := cfg.KiroKey[i] - var accessToken, profileArn string + var accessToken, profileArn, refreshToken string // Try to load from token file first if kk.TokenFile != "" && kAuth != nil { @@ -1282,6 +1282,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } else { accessToken = tokenData.AccessToken profileArn = tokenData.ProfileArn + refreshToken = tokenData.RefreshToken } } @@ -1292,6 +1293,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.ProfileArn != "" { profileArn = kk.ProfileArn } + if kk.RefreshToken != "" { + refreshToken = kk.RefreshToken + } if accessToken == "" { log.Warnf("kiro config[%d] missing access_token, skipping", i) @@ -1313,6 +1317,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.AgentTaskType != "" { attrs["agent_task_type"] = kk.AgentTaskType } + if refreshToken != "" { + attrs["refresh_token"] = refreshToken + } proxyURL := strings.TrimSpace(kk.ProxyURL) a := &coreauth.Auth{ ID: id, @@ -1324,6 +1331,14 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + + if refreshToken != "" { + if a.Metadata == nil { + a.Metadata = make(map[string]any) + } + a.Metadata["refresh_token"] = refreshToken + } + out = append(out, a) } for i := range cfg.OpenAICompatibility { From a594338bc57cc6df86b90a5ce10a72f2de880a07 Mon Sep 17 00:00:00 2001 From: fuko2935 Date: Tue, 9 Dec 2025 19:14:40 +0300 Subject: [PATCH 20/38] fix(registry): remove unstable kiro-auto model - Removes kiro-auto from static model registry - Removes kiro-auto mapping from executor - Fixes compatibility issues reported in #7 
Fixes #7 --- internal/registry/model_definitions.go | 11 ----------- internal/runtime/executor/kiro_executor.go | 2 -- 2 files changed, 13 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 3c31e61f..31f08f98 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -1195,17 +1195,6 @@ func GetGitHubCopilotModels() []*ModelInfo { // GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions func GetKiroModels() []*ModelInfo { return []*ModelInfo{ - { - ID: "kiro-auto", - Object: "model", - Created: 1732752000, // 2024-11-28 - OwnedBy: "aws", - Type: "kiro", - DisplayName: "Kiro Auto", - Description: "Automatic model selection by AWS CodeWhisperer", - ContextLength: 200000, - MaxCompletionTokens: 64000, - }, { ID: "kiro-claude-opus-4.5", Object: "model", diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 0157d68c..b965c9ca 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -547,8 +547,6 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { // Agentic variants (-agentic suffix) map to the same backend model IDs. 
func (e *KiroExecutor) mapModelToKiro(model string) string { modelMap := map[string]string{ - // Proxy format (kiro- prefix) - "kiro-auto": "auto", "kiro-claude-opus-4.5": "claude-opus-4.5", "kiro-claude-sonnet-4.5": "claude-sonnet-4.5", "kiro-claude-sonnet-4": "claude-sonnet-4", From 084e2666cb261f0fc0c2ee92e0f38bbb6e0e5e97 Mon Sep 17 00:00:00 2001 From: Ravens Date: Thu, 11 Dec 2025 00:13:44 +0800 Subject: [PATCH 21/38] fix(kiro): add SSE event: prefix for Claude client compatibility Amp-Thread-ID: https://ampcode.com/threads/T-019b08fc-ff96-766e-a942-63dd35ed28c6 Co-authored-by: Amp --- .../translator/kiro/claude/kiro_claude.go | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go index 9922860e..335873a7 100644 --- a/internal/translator/kiro/claude/kiro_claude.go +++ b/internal/translator/kiro/claude/kiro_claude.go @@ -1,10 +1,14 @@ // Package claude provides translation between Kiro and Claude formats. // Since Kiro uses Claude-compatible format internally, translations are mostly pass-through. +// However, SSE events require proper "event: " prefix for Claude clients. package claude import ( "bytes" "context" + "strings" + + "github.com/tidwall/gjson" ) // ConvertClaudeRequestToKiro converts Claude request to Kiro format. @@ -14,8 +18,52 @@ func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bo } // ConvertKiroResponseToClaude converts Kiro streaming response to Claude format. +// It adds the required "event: " prefix for SSE compliance with Claude clients. 
+// Input format: "data: {\"type\":\"message_start\",...}" +// Output format: "event: message_start\ndata: {\"type\":\"message_start\",...}" func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { - return []string{string(rawResponse)} + raw := string(rawResponse) + + // Handle multiple data blocks (e.g., message_delta + message_stop) + lines := strings.Split(raw, "\n\n") + var results []string + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Extract event type from JSON and add "event:" prefix + formatted := addEventPrefix(line) + if formatted != "" { + results = append(results, formatted) + } + } + + if len(results) == 0 { + return []string{raw} + } + + return results +} + +// addEventPrefix extracts the event type from the data line and adds the event: prefix. +// Input: "data: {\"type\":\"message_start\",...}" +// Output: "event: message_start\ndata: {\"type\":\"message_start\",...}" +func addEventPrefix(dataLine string) string { + if !strings.HasPrefix(dataLine, "data: ") { + return dataLine + } + + jsonPart := strings.TrimPrefix(dataLine, "data: ") + eventType := gjson.Get(jsonPart, "type").String() + + if eventType == "" { + return dataLine + } + + return "event: " + eventType + "\n" + dataLine } // ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format. 
From 8d5f89ccfd02f4a3228a431f6d3cbbaf8a8887a1 Mon Sep 17 00:00:00 2001 From: Ravens Date: Thu, 11 Dec 2025 01:15:00 +0800 Subject: [PATCH 22/38] fix(kiro): fix translator format mismatch for OpenAI protocol Amp-Thread-ID: https://ampcode.com/threads/T-019b092b-f2de-72a1-b428-72511c0de628 Co-authored-by: Amp --- internal/runtime/executor/kiro_executor.go | 51 ++++++++--------- .../translator/kiro/claude/kiro_claude.go | 55 ++----------------- .../chat-completions/kiro_openai_response.go | 48 +++++++++++++++- 3 files changed, 77 insertions(+), 77 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index b965c9ca..b69fd8be 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -1323,7 +1323,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Send message_start on first event if !messageStartSent { msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1372,7 +1372,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out contentBlockIndex++ isTextBlockOpen = true blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, 
&translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1381,7 +1381,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1404,7 +1404,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close text block if open before starting tool_use block if isTextBlockOpen && contentBlockIndex >= 0 { blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1418,7 +1418,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out toolName := getString(tu, "name") blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, 
&translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1433,7 +1433,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Don't continue - still need to close the block } else { inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1444,7 +1444,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close tool_use block (always close even if input marshal failed) blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1464,7 +1464,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close text block if open if isTextBlockOpen && contentBlockIndex >= 0 { blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, 
originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1476,7 +1476,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out contentBlockIndex++ blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1489,7 +1489,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) } else { inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1499,7 +1499,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, 
blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1530,7 +1530,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close content block if open if isTextBlockOpen && contentBlockIndex >= 0 { blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1555,7 +1555,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Send message_delta and message_stop msgStop := e.buildClaudeMessageStopEvent(stopReason, totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1566,6 +1566,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Claude SSE event builders +// All builders return complete SSE format with "event:" line for Claude client compatibility. 
func (e *KiroExecutor) buildClaudeMessageStartEvent(model string, inputTokens int64) []byte { event := map[string]interface{}{ "type": "message_start", @@ -1581,7 +1582,7 @@ func (e *KiroExecutor) buildClaudeMessageStartEvent(model string, inputTokens in }, } result, _ := json.Marshal(event) - return []byte("data: " + string(result)) + return []byte("event: message_start\ndata: " + string(result)) } func (e *KiroExecutor) buildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { @@ -1606,7 +1607,7 @@ func (e *KiroExecutor) buildClaudeContentBlockStartEvent(index int, blockType, t "content_block": contentBlock, } result, _ := json.Marshal(event) - return []byte("data: " + string(result)) + return []byte("event: content_block_start\ndata: " + string(result)) } func (e *KiroExecutor) buildClaudeStreamEvent(contentDelta string, index int) []byte { @@ -1619,7 +1620,7 @@ func (e *KiroExecutor) buildClaudeStreamEvent(contentDelta string, index int) [] }, } result, _ := json.Marshal(event) - return []byte("data: " + string(result)) + return []byte("event: content_block_delta\ndata: " + string(result)) } // buildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming @@ -1633,7 +1634,7 @@ func (e *KiroExecutor) buildClaudeInputJsonDeltaEvent(partialJSON string, index }, } result, _ := json.Marshal(event) - return []byte("data: " + string(result)) + return []byte("event: content_block_delta\ndata: " + string(result)) } func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { @@ -1642,7 +1643,7 @@ func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { "index": index, } result, _ := json.Marshal(event) - return []byte("data: " + string(result)) + return []byte("event: content_block_stop\ndata: " + string(result)) } func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo usage.Detail) []byte { @@ -1666,7 +1667,7 @@ func (e *KiroExecutor) 
buildClaudeMessageStopEvent(stopReason string, usageInfo } stopResult, _ := json.Marshal(stopEvent) - return []byte("data: " + string(deltaResult) + "\n\ndata: " + string(stopResult)) + return []byte("event: message_delta\ndata: " + string(deltaResult) + "\n\nevent: message_stop\ndata: " + string(stopResult)) } // buildClaudeFinalEvent constructs the final Claude-style event. @@ -1675,7 +1676,7 @@ func (e *KiroExecutor) buildClaudeFinalEvent() []byte { "type": "message_stop", } result, _ := json.Marshal(event) - return []byte("data: " + string(result)) + return []byte("event: message_stop\ndata: " + string(result)) } // CountTokens is not supported for the Kiro provider. @@ -1890,7 +1891,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c if !messageStartSent { msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) @@ -1921,7 +1922,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c contentBlockIndex++ isBlockOpen = true blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) @@ -1931,7 +1932,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, 
body io.Reader, c } claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) @@ -1964,7 +1965,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c // Close content block if open if isBlockOpen && contentBlockIndex >= 0 { blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) @@ -1984,7 +1985,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c // Always use end_turn (no tool_use support) msgStop := e.buildClaudeMessageStopEvent("end_turn", totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("claude"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go index 335873a7..554dbf21 100644 --- a/internal/translator/kiro/claude/kiro_claude.go +++ 
b/internal/translator/kiro/claude/kiro_claude.go @@ -1,14 +1,11 @@ // Package claude provides translation between Kiro and Claude formats. -// Since Kiro uses Claude-compatible format internally, translations are mostly pass-through. -// However, SSE events require proper "event: " prefix for Claude clients. +// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix), +// translations are pass-through. package claude import ( "bytes" "context" - "strings" - - "github.com/tidwall/gjson" ) // ConvertClaudeRequestToKiro converts Claude request to Kiro format. @@ -18,52 +15,10 @@ func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bo } // ConvertKiroResponseToClaude converts Kiro streaming response to Claude format. -// It adds the required "event: " prefix for SSE compliance with Claude clients. -// Input format: "data: {\"type\":\"message_start\",...}" -// Output format: "event: message_start\ndata: {\"type\":\"message_start\",...}" +// Kiro executor already generates complete SSE format with "event:" prefix, +// so this is a simple pass-through. func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { - raw := string(rawResponse) - - // Handle multiple data blocks (e.g., message_delta + message_stop) - lines := strings.Split(raw, "\n\n") - var results []string - - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - // Extract event type from JSON and add "event:" prefix - formatted := addEventPrefix(line) - if formatted != "" { - results = append(results, formatted) - } - } - - if len(results) == 0 { - return []string{raw} - } - - return results -} - -// addEventPrefix extracts the event type from the data line and adds the event: prefix. 
-// Input: "data: {\"type\":\"message_start\",...}" -// Output: "event: message_start\ndata: {\"type\":\"message_start\",...}" -func addEventPrefix(dataLine string) string { - if !strings.HasPrefix(dataLine, "data: ") { - return dataLine - } - - jsonPart := strings.TrimPrefix(dataLine, "data: ") - eventType := gjson.Get(jsonPart, "type").String() - - if eventType == "" { - return dataLine - } - - return "event: " + eventType + "\n" + dataLine + return []string{string(rawResponse)} } // ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format. diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go index 6a0ad250..df75cc07 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go @@ -4,6 +4,7 @@ package chat_completions import ( "context" "encoding/json" + "strings" "time" "github.com/google/uuid" @@ -13,15 +14,58 @@ import ( // ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format. // Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta, // content_block_stop, message_delta, and message_stop. +// Input may be in SSE format: "event: xxx\ndata: {...}" or raw JSON. 
func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { - root := gjson.ParseBytes(rawResponse) + raw := string(rawResponse) + var results []string + + // Handle SSE format: extract JSON from "data: " lines + // Input format: "event: message_start\ndata: {...}" + lines := strings.Split(raw, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "data: ") { + jsonPart := strings.TrimPrefix(line, "data: ") + chunks := convertClaudeEventToOpenAI(jsonPart, model) + results = append(results, chunks...) + } else if strings.HasPrefix(line, "{") { + // Raw JSON (backward compatibility) + chunks := convertClaudeEventToOpenAI(line, model) + results = append(results, chunks...) + } + } + + return results +} + +// convertClaudeEventToOpenAI converts a single Claude JSON event to OpenAI format +func convertClaudeEventToOpenAI(jsonStr string, model string) []string { + root := gjson.Parse(jsonStr) var results []string eventType := root.Get("type").String() switch eventType { case "message_start": - // Initial message event - could emit initial chunk if needed + // Initial message event - emit initial chunk with role + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "role": "assistant", + "content": "", + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) return results case "content_block_start": From cd4e84a3600782577a537f457d6f866d2dd9cf24 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Thu, 11 Dec 2025 05:24:21 +0800 Subject: [PATCH 23/38] feat(kiro): enhance request format, stream handling, and usage tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit ## English Description ### Request Format Fixes - Fix conversationState field order (chatTriggerType must be first) - Add conditional profileArn inclusion based on auth method - builder-id auth (AWS SSO) doesn't require profileArn - social auth (Google OAuth) requires profileArn ### Stream Processing Enhancements - Add headersLen boundary validation to prevent slice out of bounds - Handle incomplete tool use at EOF by flushing pending data - Separate message_delta and message_stop events for proper streaming - Add error logging for JSON unmarshal failures ### JSON Repair Improvements - Add escapeNewlinesInStrings() to handle control characters in JSON strings - Remove incorrect unquotedKeyPattern that broke valid JSON content - Fix handling of streaming fragments with embedded newlines/tabs ### Debug Info Filtering (Optional) - Add filterHeliosDebugInfo() to remove [HELIOS_CHK] blocks - Pattern matches internal state tracking from Kiro/Amazon Q - Currently disabled pending further testing ### Usage Tracking - Add usage information extraction in message_delta response - Include prompt_tokens, completion_tokens, total_tokens in OpenAI format --- ## 中文描述 ### 请求格式修复 - 修复 conversationState 字段顺序(chatTriggerType 必须在第一位) - 根据认证方式条件性包含 profileArn - builder-id 认证(AWS SSO)不需要 profileArn - social 认证(Google OAuth)需要 profileArn ### 流处理增强 - 添加 headersLen 边界验证,防止切片越界 - 在 EOF 时处理未完成的工具调用,刷新待处理数据 - 分离 message_delta 和 message_stop 事件以实现正确的流式传输 - 添加 JSON 反序列化失败的错误日志 ### JSON 修复改进 - 添加 escapeNewlinesInStrings() 处理 JSON 字符串中的控制字符 - 移除错误的 unquotedKeyPattern,该模式会破坏有效的 JSON 内容 - 修复包含嵌入换行符/制表符的流式片段处理 ### 调试信息过滤(可选) - 添加 filterHeliosDebugInfo() 移除 [HELIOS_CHK] 块 - 模式匹配来自 Kiro/Amazon Q 的内部状态跟踪信息 - 目前已禁用,等待进一步测试 ### 使用量跟踪 - 在 message_delta 响应中添加 usage 信息提取 - 以 OpenAI 格式包含 prompt_tokens、completion_tokens、total_tokens --- internal/runtime/executor/kiro_executor.go | 277 ++++++++++++++++-- .../chat-completions/kiro_openai_response.go | 15 +- 2 files changed, 260 
insertions(+), 32 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index b69fd8be..534e0c58 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -121,7 +121,12 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req return resp, fmt.Errorf("kiro: access token not found in auth") } if profileArn == "" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + // Only warn if not using builder-id auth (which doesn't need profileArn) + if auth == nil || auth.Metadata == nil { + log.Debugf("kiro: profile ARN not found in auth (may be normal for builder-id)") + } else if authMethod, ok := auth.Metadata["auth_method"].(string); !ok || authMethod != "builder-id" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) @@ -161,10 +166,19 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req currentOrigin = "CLI" } - kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + // Determine if profileArn should be included based on auth method + // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) + effectiveProfileArn := profileArn + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + effectiveProfileArn = "" // Don't include profileArn for builder-id auth + } + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, effectiveProfileArn, currentOrigin, isAgentic, isChatOnly) // Execute with retry on 401/403 and 429 (quota exhausted) - resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + resp, err = 
e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) return resp, err } @@ -330,7 +344,12 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, fmt.Errorf("kiro: access token not found in auth") } if profileArn == "" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + // Only warn if not using builder-id auth (which doesn't need profileArn) + if auth == nil || auth.Metadata == nil { + log.Debugf("kiro: profile ARN not found in auth (may be normal for builder-id)") + } else if authMethod, ok := auth.Metadata["auth_method"].(string); !ok || authMethod != "builder-id" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) @@ -370,10 +389,19 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut currentOrigin = "CLI" } - kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + // Determine if profileArn should be included based on auth method + // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) + effectiveProfileArn := profileArn + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + effectiveProfileArn = "" // Don't include profileArn for builder-id auth + } + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, effectiveProfileArn, currentOrigin, isAgentic, isChatOnly) // Execute stream with retry on 401/403 and 429 (quota exhausted) - return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, kiroPayload, 
body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) } // executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. @@ -587,10 +615,10 @@ type kiroPayload struct { } type kiroConversationState struct { + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field ConversationID string `json:"conversationId"` - History []kiroHistoryMessage `json:"history"` CurrentMessage kiroCurrentMessage `json:"currentMessage"` - ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" + History []kiroHistoryMessage `json:"history,omitempty"` // Only include when non-empty } type kiroCurrentMessage struct { @@ -805,21 +833,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, }} } + // Build payload with correct field order (matches struct definition) + // Note: history is omitempty, so nil/empty slice won't be serialized payload := kiroPayload{ ConversationState: kiroConversationState{ + ChatTriggerType: "MANUAL", // Required by Kiro API - must be first ConversationID: uuid.New().String(), - History: history, CurrentMessage: currentMessage, - ChatTriggerType: "MANUAL", // Required by Kiro API + History: history, // Will be omitted if empty due to omitempty tag }, ProfileArn: profileArn, } - // Ensure history is not nil (empty array) - if payload.ConversationState.History == nil { - payload.ConversationState.History = []kiroHistoryMessage{} - } - result, err := json.Marshal(payload) if err != nil { log.Debugf("kiro: failed to marshal payload: %v", err) @@ -1001,6 +1026,12 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err) } + // Validate headersLen to prevent slice out of bounds + if headersLen+4 > uint32(len(remaining)) { + log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + 
continue + } + // Extract event type from headers eventType := e.extractEventType(remaining[:headersLen+4]) @@ -1111,6 +1142,11 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, // Deduplicate all tool uses toolUses = deduplicateToolUses(toolUses) + // OPTIONAL: Filter out [HELIOS_CHK] debug blocks from Kiro/Amazon Q internal systems + // These blocks contain internal state tracking information that should not be exposed to users. + // To enable filtering, uncomment the following line: + // cleanedContent = filterHeliosDebugInfo(cleanedContent) + return cleanedContent, toolUses, usageInfo, nil } @@ -1279,6 +1315,51 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out prelude := make([]byte, 8) _, err := io.ReadFull(reader, prelude) if err == io.EOF { + // Flush any incomplete tool use before ending stream + if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] { + log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) + fullInput := currentToolUse.inputBuffer.String() + repairedJSON := repairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) + finalInput = make(map[string]interface{}) + } + + processedIDs[currentToolUse.toolUseID] = true + contentBlockIndex++ + + // Send tool_use content block + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.toolUseID, currentToolUse.name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send tool input as delta + inputBytes, _ := 
json.Marshal(finalInput) + inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Close block + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + hasToolUses = true + currentToolUse = nil + } break } if err != nil { @@ -1304,6 +1385,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out return } + // Validate headersLen to prevent slice out of bounds + if headersLen+4 > uint32(len(remaining)) { + log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + continue + } + eventType := e.extractEventType(remaining[:headersLen+4]) payloadStart := 4 + headersLen @@ -1317,6 +1404,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out var event map[string]interface{} if err := json.Unmarshal(payload, &event); err != nil { + log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) continue } @@ -1553,9 +1641,18 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out stopReason = "tool_use" } - // Send message_delta and message_stop - msgStop := e.buildClaudeMessageStopEvent(stopReason, totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + // Send message_delta event + msgDelta := 
e.buildClaudeMessageDeltaEvent(stopReason, totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send message_stop event separately + msgStop := e.buildClaudeMessageStopOnlyEvent() + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1646,8 +1743,8 @@ func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { return []byte("event: content_block_stop\ndata: " + string(result)) } -func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo usage.Detail) []byte { - // First message_delta +// buildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage. +func (e *KiroExecutor) buildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte { deltaEvent := map[string]interface{}{ "type": "message_delta", "delta": map[string]interface{}{ @@ -1660,14 +1757,16 @@ func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo }, } deltaResult, _ := json.Marshal(deltaEvent) + return []byte("event: message_delta\ndata: " + string(deltaResult)) +} - // Then message_stop +// buildClaudeMessageStopOnlyEvent creates only the message_stop event. 
+func (e *KiroExecutor) buildClaudeMessageStopOnlyEvent() []byte { stopEvent := map[string]interface{}{ "type": "message_stop", } stopResult, _ := json.Marshal(stopEvent) - - return []byte("event: message_delta\ndata: " + string(deltaResult) + "\n\nevent: message_stop\ndata: " + string(stopResult)) + return []byte("event: message_stop\ndata: " + string(stopResult)) } // buildClaudeFinalEvent constructs the final Claude-style event. @@ -1873,6 +1972,12 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c return fmt.Errorf("failed to read message: %w", err) } + // Validate headersLen to prevent slice out of bounds + if headersLen+4 > uint32(len(remaining)) { + log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + continue + } + eventType := e.extractEventType(remaining[:headersLen+4]) payloadStart := 4 + headersLen @@ -1886,6 +1991,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c var event map[string]interface{} if err := json.Unmarshal(payload, &event); err != nil { + log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) continue } @@ -1983,9 +2089,19 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c } totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens - // Always use end_turn (no tool_use support) - msgStop := e.buildClaudeMessageStopEvent("end_turn", totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + // Send message_delta event + msgDelta := e.buildClaudeMessageDeltaEvent("end_turn", totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + 
c.Writer.Flush() + + // Send message_stop event separately + msgStop := e.buildClaudeMessageStopOnlyEvent() + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) @@ -2079,8 +2195,12 @@ var ( whitespaceCollapsePattern = regexp.MustCompile(`\s+`) // trailingCommaPattern matches trailing commas before closing braces/brackets trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) - // unquotedKeyPattern matches unquoted JSON keys that need quoting - unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`) + + // heliosDebugPattern matches [HELIOS_CHK] debug blocks from Kiro/Amazon Q internal systems + // These blocks contain internal state tracking information that should not be exposed to users + // Format: [HELIOS_CHK]\nPhase: ...\nHelios: ...\nAction: ...\nTask: ...\nContext-Map: ...\nNext: ... + // The pattern matches from [HELIOS_CHK] to the end of the "Next:" line (which may span multiple lines until double newline) + heliosDebugPattern = regexp.MustCompile(`(?s)\[HELIOS_CHK\].*?(?:Next:.*?)(?:\n\n|\z)`) ) // parseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. @@ -2249,13 +2369,70 @@ func findMatchingBracket(text string, startPos int) int { // Based on AIClient-2-API's JSON repair implementation. // Uses pre-compiled regex patterns for performance. 
func repairJSON(raw string) string { + // First, escape unescaped newlines/tabs within JSON string values + repaired := escapeNewlinesInStrings(raw) // Remove trailing commas before closing braces/brackets - repaired := trailingCommaPattern.ReplaceAllString(raw, "$1") - // Fix unquoted keys (basic attempt - handles simple cases) - repaired = unquotedKeyPattern.ReplaceAllString(repaired, `$1"$2":`) + repaired = trailingCommaPattern.ReplaceAllString(repaired, "$1") + // Note: unquotedKeyPattern removed - it incorrectly matches content inside + // JSON string values (e.g., "classDef fill:#1a1a2e,stroke:#00d4ff" would have + // ",stroke:" incorrectly treated as an unquoted key) return repaired } +// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters +// that appear inside JSON string values. This handles cases where streaming fragments +// contain unescaped control characters within string content. +func escapeNewlinesInStrings(raw string) string { + var result strings.Builder + result.Grow(len(raw) + 100) // Pre-allocate with some extra space + + inString := false + escaped := false + + for i := 0; i < len(raw); i++ { + c := raw[i] + + if escaped { + // Previous character was backslash, this is an escape sequence + result.WriteByte(c) + escaped = false + continue + } + + if c == '\\' && inString { + // Start of escape sequence + result.WriteByte(c) + escaped = true + continue + } + + if c == '"' { + // Toggle string state + inString = !inString + result.WriteByte(c) + continue + } + + if inString { + // Inside a string, escape control characters + switch c { + case '\n': + result.WriteString("\\n") + case '\r': + result.WriteString("\\r") + case '\t': + result.WriteString("\\t") + default: + result.WriteByte(c) + } + } else { + result.WriteByte(c) + } + } + + return result.String() +} + // processToolUseEvent handles a toolUseEvent from the Kiro stream. // It accumulates input fragments and emits tool_use blocks when complete. 
// Returns events to emit and updated state. @@ -2330,6 +2507,8 @@ func (e *KiroExecutor) processToolUseEvent(event map[string]interface{}, current // Accumulate input fragments if currentToolUse != nil && inputFragment != "" { + // Accumulate fragments directly - they form valid JSON when combined + // The fragments are already decoded from JSON, so we just concatenate them currentToolUse.inputBuffer.WriteString(inputFragment) log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.inputBuffer.Len()) } @@ -2389,3 +2568,39 @@ func deduplicateToolUses(toolUses []kiroToolUse) []kiroToolUse { return unique } + +// filterHeliosDebugInfo removes [HELIOS_CHK] debug blocks from Kiro/Amazon Q responses. +// These blocks contain internal state tracking information from the Helios system +// that should not be exposed to end users. +// +// The [HELIOS_CHK] block format typically looks like: +// +// [HELIOS_CHK] +// Phase: E (Evaluate & Evolve) +// Helios: [P1_Intent: OK] [P2_Research: OK] [P3_Strategy: OK] +// Action: [TEXT_RESPONSE] +// Task: #T001 - Some task description +// Context-Map: [SYNCED] +// Next: Some next action description... +// +// This function is currently DISABLED (commented out in callers) pending further testing. +// To enable, uncomment the filterHeliosDebugInfo calls in parseEventStream() and streamToChannel(). 
+func filterHeliosDebugInfo(content string) string { + if !strings.Contains(content, "[HELIOS_CHK]") { + return content + } + + // Remove [HELIOS_CHK] blocks + filtered := heliosDebugPattern.ReplaceAllString(content, "") + + // Clean up any resulting double newlines or leading/trailing whitespace + filtered = strings.TrimSpace(filtered) + + // Log when filtering occurs for debugging purposes + if filtered != content { + log.Debugf("kiro: filtered HELIOS debug info from response (original len: %d, filtered len: %d)", + len(content), len(filtered)) + } + + return filtered +} diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go index df75cc07..d56c94ac 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go @@ -171,7 +171,7 @@ func convertClaudeEventToOpenAI(jsonStr string, model string) []string { return results case "message_delta": - // Final message delta with stop_reason + // Final message delta with stop_reason and usage stopReason := root.Get("delta.stop_reason").String() if stopReason != "" { finishReason := "stop" @@ -196,6 +196,19 @@ func convertClaudeEventToOpenAI(jsonStr string, model string) []string { }, }, } + + // Extract and include usage information from message_delta event + usage := root.Get("usage") + if usage.Exists() { + inputTokens := usage.Get("input_tokens").Int() + outputTokens := usage.Get("output_tokens").Int() + response["usage"] = map[string]interface{}{ + "prompt_tokens": inputTokens, + "completion_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + } + } + result, _ := json.Marshal(response) results = append(results, string(result)) } From 6133bac226d098406a53e78baaaf61bca1e49907 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Thu, 11 Dec 2025 08:10:11 +0800 Subject: [PATCH 24/38] feat(kiro): enhance 
Kiro executor stability and compatibility ## Changes Overview This commit includes multiple improvements to the Kiro executor for better stability, API compatibility, and code quality. ## Detailed Changes ### 1. Output Token Calculation Improvement (lines 317-330) - Replace simple len(content)/4 estimation with tiktoken-based calculation - Add fallback to character count estimation if tiktoken fails - Improves token counting accuracy for usage tracking ### 2. Stream Handler Panic Recovery (lines 528-533) - Add defer/recover block in streamToChannel goroutine - Prevents single request crashes from affecting the entire service ### 3. Struct Field Reordering (lines 670-673) - Reorder kiroToolResult struct fields: Content, Status, ToolUseID - Ensures consistency with API expectations ### 4. Message Merging Function (lines 778-780, 2356-2483) - Add mergeAdjacentMessages() to combine consecutive messages with same role - Add helper functions: mergeMessageContent(), blockToMap(), createMergedMessage() - Required by Kiro API which doesn't allow adjacent messages from same role ### 5. Empty Content Handling (lines 791-800) - Add default content for empty history messages - User messages with tool results: "Tool results provided." - User messages without tool results: "Continue" ### 6. Assistant Last Message Handling (lines 811-830) - Detect when last message is from assistant - Create synthetic "Continue" user message to satisfy Kiro API requirements - Kiro API requires currentMessage to be userInputMessage type ### 7. Duplicate Content Event Detection (lines 1650-1660) - Track lastContentEvent to detect duplicate streaming events - Skip redundant events to prevent duplicate content in responses - Based on AIClient-2-API implementation for Kiro ### 8. 
Streaming Token Calculation Enhancement (lines 1785-1817) - Add accumulatedContent buffer for streaming token calculation - Use tiktoken for accurate output token counting during streaming - Add fallback to character count estimation with proper logging ### 9. JSON Repair Enhancement (lines 2665-2818) - Implement conservative JSON repair strategy - First try to parse JSON directly - if valid, return unchanged - Add bracket balancing detection and repair - Only repair when necessary to avoid corrupting valid JSON - Validate repaired JSON before returning ### 10. HELIOS_CHK Filtering Removal (lines 2500-2504, 3004-3039) - Remove filterHeliosDebugInfo function - Remove heliosDebugPattern regex - HELIOS_CHK fields now handled by client-side processing ### 11. Comment Translation - Translate Chinese comments to English for code consistency - Affected areas: token calculation, buffer handling, message processing --- internal/runtime/executor/kiro_executor.go | 534 ++++++++++++++++++--- 1 file changed, 467 insertions(+), 67 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 534e0c58..84fd990c 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -315,9 +315,18 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
} } if len(content) > 0 { - usageInfo.OutputTokens = int64(len(content) / 4) + // Use tiktoken for more accurate output token calculation + if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if tokenCount, countErr := enc.Count(content); countErr == nil { + usageInfo.OutputTokens = int64(tokenCount) + } + } + // Fallback to character count estimation if tiktoken fails if usageInfo.OutputTokens == 0 { - usageInfo.OutputTokens = 1 + usageInfo.OutputTokens = int64(len(content) / 4) + if usageInfo.OutputTokens == 0 { + usageInfo.OutputTokens = 1 + } } } usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens @@ -519,6 +528,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox go func(resp *http.Response) { defer close(out) + defer func() { + if r := recover(); r != nil { + log.Errorf("kiro: panic in stream handler: %v", r) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} + } + }() defer func() { if errClose := resp.Body.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) @@ -655,9 +670,9 @@ type kiroUserInputMessageContext struct { } type kiroToolResult struct { - ToolUseID string `json:"toolUseId"` Content []kiroTextContent `json:"content"` Status string `json:"status"` + ToolUseID string `json:"toolUseId"` } type kiroTextContent struct { @@ -763,7 +778,9 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, var currentUserMsg *kiroUserInputMessage var currentToolResults []kiroToolResult - messagesArray := messages.Array() + // Merge adjacent messages with the same role before processing + // This reduces API call complexity and improves compatibility + messagesArray := mergeAdjacentMessages(messages.Array()) for i, msg := range messagesArray { role := msg.Get("role").String() isLastMessage := i == len(messagesArray)-1 @@ -774,6 +791,14 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, 
currentUserMsg = &userMsg currentToolResults = toolResults } else { + // CRITICAL: Kiro API requires content to be non-empty for history messages too + if strings.TrimSpace(userMsg.Content) == "" { + if len(toolResults) > 0 { + userMsg.Content = "Tool results provided." + } else { + userMsg.Content = "Continue" + } + } // For history messages, embed tool results in context if len(toolResults) > 0 { userMsg.UserInputMessageContext = &kiroUserInputMessageContext{ @@ -786,9 +811,24 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, } } else if role == "assistant" { assistantMsg := e.buildAssistantMessageStruct(msg) - history = append(history, kiroHistoryMessage{ - AssistantResponseMessage: &assistantMsg, - }) + // If this is the last message and it's an assistant message, + // we need to add it to history and create a "Continue" user message + // because Kiro API requires currentMessage to be userInputMessage type + if isLastMessage { + history = append(history, kiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + // Create a "Continue" user message as currentMessage + currentUserMsg = &kiroUserInputMessage{ + Content: "Continue", + ModelID: modelID, + Origin: origin, + } + } else { + history = append(history, kiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } } } @@ -805,7 +845,35 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // Add the actual user message contentBuilder.WriteString(currentUserMsg.Content) - currentUserMsg.Content = contentBuilder.String() + finalContent := contentBuilder.String() + + // CRITICAL: Kiro API requires content to be non-empty, even when toolResults are present + // If content is empty or only whitespace, provide a default message + if strings.TrimSpace(finalContent) == "" { + if len(currentToolResults) > 0 { + finalContent = "Tool results provided." 
+ } else { + finalContent = "Continue" + } + log.Debugf("kiro: content was empty, using default: %s", finalContent) + } + currentUserMsg.Content = finalContent + + // Deduplicate currentToolResults before adding to context + // Kiro API does not accept duplicate toolUseIds + if len(currentToolResults) > 0 { + seenIDs := make(map[string]bool) + uniqueToolResults := make([]kiroToolResult, 0, len(currentToolResults)) + for _, tr := range currentToolResults { + if !seenIDs[tr.ToolUseID] { + seenIDs[tr.ToolUseID] = true + uniqueToolResults = append(uniqueToolResults, tr) + } else { + log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID) + } + } + currentToolResults = uniqueToolResults + } // Build userInputMessageContext with tools and tool results if len(kiroTools) > 0 || len(currentToolResults) > 0 { @@ -855,11 +923,15 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // buildUserMessageStruct builds a user message and extracts tool results // origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +// IMPORTANT: Kiro API does not accept duplicate toolUseIds, so we deduplicate here. 
func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin string) (kiroUserInputMessage, []kiroToolResult) { content := msg.Get("content") var contentBuilder strings.Builder var toolResults []kiroToolResult var images []kiroImage + + // Track seen toolUseIds to deduplicate - Kiro API rejects duplicate toolUseIds + seenToolUseIDs := make(map[string]bool) if content.IsArray() { for _, part := range content.Array() { @@ -889,6 +961,14 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin case "tool_result": // Extract tool result for API toolUseID := part.Get("tool_use_id").String() + + // Skip duplicate toolUseIds - Kiro API does not accept duplicates + if seenToolUseIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID) + continue + } + seenToolUseIDs[toolUseID] = true + isError := part.Get("is_error").Bool() resultContent := part.Get("content") @@ -1049,6 +1129,37 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, continue } + // DIAGNOSTIC: Log all received event types for debugging + log.Debugf("kiro: parseEventStream received event type: %s", eventType) + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("kiro: parseEventStream event payload: %s", string(payload)) + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg) + return "", nil, usageInfo, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == 
"error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + return "", nil, usageInfo, fmt.Errorf("kiro API error: %s", errMsg) + } + // Handle different event types switch eventType { case "assistantResponseEvent": @@ -1142,11 +1253,6 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, // Deduplicate all tool uses toolUses = deduplicateToolUses(toolUses) - // OPTIONAL: Filter out [HELIOS_CHK] debug blocks from Kiro/Amazon Q internal systems - // These blocks contain internal state tracking information that should not be exposed to users. - // To enable filtering, uncomment the following line: - // cleanedContent = filterHeliosDebugInfo(cleanedContent) - return cleanedContent, toolUses, usageInfo, nil } @@ -1267,8 +1373,9 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs // streamToChannel converts AWS Event Stream to channel-based streaming. // Supports tool calling - emits tool_use content blocks when tools are used. // Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. +// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). 
func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { - reader := bufio.NewReader(body) + reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers var totalUsage usage.Detail var hasToolUses bool // Track if any tool uses were emitted @@ -1276,6 +1383,15 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out processedIDs := make(map[string]bool) var currentToolUse *toolUseState + // Duplicate content detection - tracks last content event to filter duplicates + // Based on AIClient-2-API implementation for Kiro + var lastContentEvent string + + // Streaming token calculation - accumulate content for real-time token counting + // Based on AIClient-2-API implementation + var accumulatedContent strings.Builder + accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations + // Translator param for maintaining tool call state across streaming events // IMPORTANT: This must persist across all TranslateStream calls var translatorParam any @@ -1408,6 +1524,39 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out continue } + // DIAGNOSTIC: Log all received event types for debugging + log.Debugf("kiro: streamToChannel received event type: %s", eventType) + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("kiro: streamToChannel event payload: %s", string(payload)) + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + 
log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)} + return + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)} + return + } + // Send message_start on first event if !messageStartSent { msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) @@ -1452,9 +1601,19 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Handle text content + // Handle text content with duplicate detection if contentDelta != "" { + // Check for duplicate content - skip if identical to last content event + // Based on AIClient-2-API implementation for Kiro + if contentDelta == lastContentEvent { + log.Debugf("kiro: skipping duplicate content event (len: %d)", len(contentDelta)) + continue + } + lastContentEvent = contentDelta + outputLen += len(contentDelta) + // Accumulate content for streaming token calculation + accumulatedContent.WriteString(contentDelta) // Start text content block if needed if !isTextBlockOpen { contentBlockIndex++ @@ -1626,8 +1785,32 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Fallback for output tokens if not received from upstream - if totalUsage.OutputTokens == 0 && outputLen > 0 { + // Streaming token calculation - calculate output tokens from accumulated content + // This provides more accurate token counting than simple character division + if 
totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { + // Try to use tiktoken for accurate counting + if enc, err := tokenizerForModel(model); err == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + totalUsage.OutputTokens = int64(tokenCount) + log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) + } else { + // Fallback on count error: estimate from character count + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens) + } + } else { + // Fallback: estimate from character count (roughly 4 chars per token) + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len()) + } + } else if totalUsage.OutputTokens == 0 && outputLen > 0 { + // Legacy fallback using outputLen totalUsage.OutputTokens = int64(outputLen / 4) if totalUsage.OutputTokens == 0 { totalUsage.OutputTokens = 1 @@ -2173,6 +2356,128 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } +// ============================================================================ +// Message Merging Support - Merge adjacent messages with the same role +// Based on AIClient-2-API implementation +// ============================================================================ + +// mergeAdjacentMessages merges adjacent messages with the same role. +// This reduces API call complexity and improves compatibility. +// Based on AIClient-2-API implementation. 
+func mergeAdjacentMessages(messages []gjson.Result) []gjson.Result { + if len(messages) <= 1 { + return messages + } + + var merged []gjson.Result + for _, msg := range messages { + if len(merged) == 0 { + merged = append(merged, msg) + continue + } + + lastMsg := merged[len(merged)-1] + currentRole := msg.Get("role").String() + lastRole := lastMsg.Get("role").String() + + if currentRole == lastRole { + // Merge content from current message into last message + mergedContent := mergeMessageContent(lastMsg, msg) + // Create a new merged message JSON + mergedMsg := createMergedMessage(lastRole, mergedContent) + merged[len(merged)-1] = gjson.Parse(mergedMsg) + } else { + merged = append(merged, msg) + } + } + + return merged +} + +// mergeMessageContent merges the content of two messages with the same role. +// Handles both string content and array content (with text, tool_use, tool_result blocks). +func mergeMessageContent(msg1, msg2 gjson.Result) string { + content1 := msg1.Get("content") + content2 := msg2.Get("content") + + // Extract content blocks from both messages + var blocks1, blocks2 []map[string]interface{} + + if content1.IsArray() { + for _, block := range content1.Array() { + blocks1 = append(blocks1, blockToMap(block)) + } + } else if content1.Type == gjson.String { + blocks1 = append(blocks1, map[string]interface{}{ + "type": "text", + "text": content1.String(), + }) + } + + if content2.IsArray() { + for _, block := range content2.Array() { + blocks2 = append(blocks2, blockToMap(block)) + } + } else if content2.Type == gjson.String { + blocks2 = append(blocks2, map[string]interface{}{ + "type": "text", + "text": content2.String(), + }) + } + + // Merge text blocks if both end/start with text + if len(blocks1) > 0 && len(blocks2) > 0 { + if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" { + // Merge the last text block of msg1 with the first text block of msg2 + text1 := blocks1[len(blocks1)-1]["text"].(string) + text2 := 
blocks2[0]["text"].(string) + blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2 + blocks2 = blocks2[1:] // Remove the merged block from blocks2 + } + } + + // Combine all blocks + allBlocks := append(blocks1, blocks2...) + + // Convert to JSON + result, _ := json.Marshal(allBlocks) + return string(result) +} + +// blockToMap converts a gjson.Result block to a map[string]interface{} +func blockToMap(block gjson.Result) map[string]interface{} { + result := make(map[string]interface{}) + block.ForEach(func(key, value gjson.Result) bool { + if value.IsObject() { + result[key.String()] = blockToMap(value) + } else if value.IsArray() { + var arr []interface{} + for _, item := range value.Array() { + if item.IsObject() { + arr = append(arr, blockToMap(item)) + } else { + arr = append(arr, item.Value()) + } + } + result[key.String()] = arr + } else { + result[key.String()] = value.Value() + } + return true + }) + return result +} + +// createMergedMessage creates a JSON string for a merged message +func createMergedMessage(role string, content string) string { + msg := map[string]interface{}{ + "role": role, + "content": json.RawMessage(content), + } + result, _ := json.Marshal(msg) + return string(result) +} + // ============================================================================ // Tool Calling Support - Embedded tool call parsing and input buffering // Based on amq2api and AIClient-2-API implementations @@ -2195,12 +2500,6 @@ var ( whitespaceCollapsePattern = regexp.MustCompile(`\s+`) // trailingCommaPattern matches trailing commas before closing braces/brackets trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) - - // heliosDebugPattern matches [HELIOS_CHK] debug blocks from Kiro/Amazon Q internal systems - // These blocks contain internal state tracking information that should not be exposed to users - // Format: [HELIOS_CHK]\nPhase: ...\nHelios: ...\nAction: ...\nTask: ...\nContext-Map: ...\nNext: ... 
- // The pattern matches from [HELIOS_CHK] to the end of the "Next:" line (which may span multiple lines until double newline) - heliosDebugPattern = regexp.MustCompile(`(?s)\[HELIOS_CHK\].*?(?:Next:.*?)(?:\n\n|\z)`) ) // parseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. @@ -2366,17 +2665,154 @@ func findMatchingBracket(text string, startPos int) int { } // repairJSON attempts to fix common JSON issues that may occur in tool call arguments. -// Based on AIClient-2-API's JSON repair implementation. +// Based on AIClient-2-API's JSON repair implementation with a more conservative strategy. +// +// Conservative repair strategy: +// 1. First try to parse JSON directly - if valid, return as-is +// 2. Only attempt repair if parsing fails +// 3. After repair, validate the result - if still invalid, return original +// +// Handles incomplete JSON by balancing brackets and removing trailing incomplete content. // Uses pre-compiled regex patterns for performance. 
-func repairJSON(raw string) string { +func repairJSON(jsonString string) string { + // Handle empty or invalid input + if jsonString == "" { + return "{}" + } + + str := strings.TrimSpace(jsonString) + if str == "" { + return "{}" + } + + // CONSERVATIVE STRATEGY: First try to parse directly + // If the JSON is already valid, return it unchanged + var testParse interface{} + if err := json.Unmarshal([]byte(str), &testParse); err == nil { + log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged") + return str + } + + log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair") + originalStr := str // Keep original for fallback + // First, escape unescaped newlines/tabs within JSON string values - repaired := escapeNewlinesInStrings(raw) + str = escapeNewlinesInStrings(str) // Remove trailing commas before closing braces/brackets - repaired = trailingCommaPattern.ReplaceAllString(repaired, "$1") - // Note: unquotedKeyPattern removed - it incorrectly matches content inside - // JSON string values (e.g., "classDef fill:#1a1a2e,stroke:#00d4ff" would have - // ",stroke:" incorrectly treated as an unquoted key) - return repaired + str = trailingCommaPattern.ReplaceAllString(str, "$1") + + // Calculate bracket balance to detect incomplete JSON + braceCount := 0 // {} balance + bracketCount := 0 // [] balance + inString := false + escape := false + lastValidIndex := -1 + + for i := 0; i < len(str); i++ { + char := str[i] + + // Handle escape sequences + if escape { + escape = false + continue + } + + if char == '\\' { + escape = true + continue + } + + // Handle string boundaries + if char == '"' { + inString = !inString + continue + } + + // Skip characters inside strings (they don't affect bracket balance) + if inString { + continue + } + + // Track bracket balance + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + + // Record last valid position (where brackets 
are balanced or positive) + if braceCount >= 0 && bracketCount >= 0 { + lastValidIndex = i + } + } + + // If brackets are unbalanced, try to repair + if braceCount > 0 || bracketCount > 0 { + // Truncate to last valid position if we have incomplete content + if lastValidIndex > 0 && lastValidIndex < len(str)-1 { + // Check if truncation would help (only truncate if there's trailing garbage) + truncated := str[:lastValidIndex+1] + // Recount brackets after truncation + braceCount = 0 + bracketCount = 0 + inString = false + escape = false + for i := 0; i < len(truncated); i++ { + char := truncated[i] + if escape { + escape = false + continue + } + if char == '\\' { + escape = true + continue + } + if char == '"' { + inString = !inString + continue + } + if inString { + continue + } + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + } + str = truncated + } + + // Add missing closing brackets + for braceCount > 0 { + str += "}" + braceCount-- + } + for bracketCount > 0 { + str += "]" + bracketCount-- + } + } + + // CONSERVATIVE STRATEGY: Validate repaired JSON + // If repair didn't produce valid JSON, return original string + if err := json.Unmarshal([]byte(str), &testParse); err != nil { + log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original") + return originalStr + } + + log.Debugf("kiro: repairJSON - successfully repaired JSON") + return str } // escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters @@ -2568,39 +3004,3 @@ func deduplicateToolUses(toolUses []kiroToolUse) []kiroToolUse { return unique } - -// filterHeliosDebugInfo removes [HELIOS_CHK] debug blocks from Kiro/Amazon Q responses. -// These blocks contain internal state tracking information from the Helios system -// that should not be exposed to end users. 
-// -// The [HELIOS_CHK] block format typically looks like: -// -// [HELIOS_CHK] -// Phase: E (Evaluate & Evolve) -// Helios: [P1_Intent: OK] [P2_Research: OK] [P3_Strategy: OK] -// Action: [TEXT_RESPONSE] -// Task: #T001 - Some task description -// Context-Map: [SYNCED] -// Next: Some next action description... -// -// This function is currently DISABLED (commented out in callers) pending further testing. -// To enable, uncomment the filterHeliosDebugInfo calls in parseEventStream() and streamToChannel(). -func filterHeliosDebugInfo(content string) string { - if !strings.Contains(content, "[HELIOS_CHK]") { - return content - } - - // Remove [HELIOS_CHK] blocks - filtered := heliosDebugPattern.ReplaceAllString(content, "") - - // Clean up any resulting double newlines or leading/trailing whitespace - filtered = strings.TrimSpace(filtered) - - // Log when filtering occurs for debugging purposes - if filtered != content { - log.Debugf("kiro: filtered HELIOS debug info from response (original len: %d, filtered len: %d)", - len(content), len(filtered)) - } - - return filtered -} From ef0edbfe697fed04a271a6cdee719670c22fa31e Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 11 Dec 2025 22:34:06 +0800 Subject: [PATCH 25/38] refactor(claude): replace `strings.Builder` with simpler `output` string concatenation --- .../claude/gemini-cli_claude_response.go | 81 +++++++++++-------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 92061086..ca905f9e 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -69,12 +69,12 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Track whether tools are being used in this response chunk usedTool := false - var sb strings.Builder + output := "" 
// Initialize the streaming session with a message_start event // This is only sent for the very first response chunk to establish the streaming session if !(*param).(*Params).HasFirstResponse { - sb.WriteString("event: message_start\n") + output = "event: message_start\n" // Create the initial message structure with default values according to Claude Code API specification // This follows the Claude Code API specification for streaming message initialization @@ -87,7 +87,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String()) } - sb.WriteString(fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)) + output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) (*param).(*Params).HasFirstResponse = true } @@ -110,7 +110,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque if partResult.Get("thought").Bool() { // Continue existing thinking block if already in thinking state if (*param).(*Params).ResponseType == 2 { - sb.WriteString("event: content_block_delta\n") + output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) output = output + fmt.Sprintf("data: %s\n\n\n", data) (*param).(*Params).HasContent = true @@ -118,19 +118,24 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Transition from another state to thinking // First, close any existing content block if (*param).(*Params).ResponseType != 0 { - sb.WriteString("event: content_block_stop\n") - sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - 
sb.WriteString("\n\n\n") + if (*param).(*Params).ResponseType == 2 { + // output = output + "event: content_block_delta\n" + // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) + // output = output + "\n\n\n" + } + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) + output = output + "\n\n\n" (*param).(*Params).ResponseIndex++ } // Start a new thinking content block - sb.WriteString("event: content_block_start\n") - sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)) - sb.WriteString("\n\n\n") - sb.WriteString("event: content_block_delta\n") + output = output + "event: content_block_start\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex) + output = output + "\n\n\n" + output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) - sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) + output = output + fmt.Sprintf("data: %s\n\n\n", data) (*param).(*Params).ResponseType = 2 // Set state to thinking (*param).(*Params).HasContent = true } @@ -138,7 +143,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Process regular text content (user-visible output) // Continue existing text block if already in content state if (*param).(*Params).ResponseType == 1 { - sb.WriteString("event: content_block_delta\n") + output = output + "event: content_block_delta\n" data, _ := 
sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String()) output = output + fmt.Sprintf("data: %s\n\n\n", data) (*param).(*Params).HasContent = true @@ -146,19 +151,24 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Transition from another state to text content // First, close any existing content block if (*param).(*Params).ResponseType != 0 { - sb.WriteString("event: content_block_stop\n") - sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - sb.WriteString("\n\n\n") + if (*param).(*Params).ResponseType == 2 { + // output = output + "event: content_block_delta\n" + // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) + // output = output + "\n\n\n" + } + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) + output = output + "\n\n\n" (*param).(*Params).ResponseIndex++ } // Start a new text content block - sb.WriteString("event: content_block_start\n") - sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)) - sb.WriteString("\n\n\n") - sb.WriteString("event: content_block_delta\n") + output = output + "event: content_block_start\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex) + output = output + "\n\n\n" + output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), 
"delta.text", partTextResult.String()) - sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) + output = output + fmt.Sprintf("data: %s\n\n\n", data) (*param).(*Params).ResponseType = 1 // Set state to content (*param).(*Params).HasContent = true } @@ -172,35 +182,42 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Handle state transitions when switching to function calls // Close any existing function call block first if (*param).(*Params).ResponseType == 3 { - sb.WriteString("event: content_block_stop\n") - sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - sb.WriteString("\n\n\n") + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) + output = output + "\n\n\n" (*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseType = 0 } + // Special handling for thinking state transition + if (*param).(*Params).ResponseType == 2 { + // output = output + "event: content_block_delta\n" + // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) + // output = output + "\n\n\n" + } + // Close any other existing content block if (*param).(*Params).ResponseType != 0 { - sb.WriteString("event: content_block_stop\n") - sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - sb.WriteString("\n\n\n") + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex) + output = output + "\n\n\n" (*param).(*Params).ResponseIndex++ } // Start a new tool use content block // This creates the structure for a function call in Claude Code format - sb.WriteString("event: content_block_start\n") + output = 
output + "event: content_block_start\n" // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) - sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) + output = output + fmt.Sprintf("data: %s\n\n\n", data) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - sb.WriteString("event: content_block_delta\n") + output = output + "event: content_block_delta\n" data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw) - sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) + output = output + fmt.Sprintf("data: %s\n\n\n", data) } (*param).(*Params).ResponseType = 3 (*param).(*Params).HasContent = true @@ -240,7 +257,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque } } - return []string{sb.String()} + return []string{output} } // ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. 
From 40e7f066e493b918705d0043e27d709dd5e8cd58 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Fri, 12 Dec 2025 01:59:06 +0800 Subject: [PATCH 26/38] feat(kiro): enhance Kiro executor with retry, deduplication and event filtering --- internal/api/server.go | 6 ++ internal/runtime/executor/kiro_executor.go | 82 +++++++++++++++++++--- 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/internal/api/server.go b/internal/api/server.go index e1cea9e9..ade08fef 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -349,6 +349,12 @@ func (s *Server) setupRoutes() { }, }) }) + + // Event logging endpoint - handles Claude Code telemetry requests + // Returns 200 OK to prevent 404 errors in logs + s.engine.POST("/api/event_logging/batch", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) // OAuth callback endpoints (reuse main server port) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 84fd990c..bcdebe1f 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -255,6 +255,26 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } + // Handle 5xx server errors with exponential backoff retry + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + // Exponential backoff: 1s, 2s, 4s... 
(max 30s) + backoff := time.Duration(1<<attempt) * time.Second + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) + continue + } + log.Errorf("kiro: server error %d after %d retries", httpResp.StatusCode, maxRetries) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + // Handle 401/403 errors with token refresh and retry if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 { respBody, _ := io.ReadAll(httpResp.Body) @@ -485,6 +505,26 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + // Handle 5xx server errors with exponential backoff retry + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + // Exponential backoff: 1s, 2s, 4s... 
(max 30s) + backoff := time.Duration(1<<attempt) * time.Second + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) + continue + } + log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + // Handle 401/403 errors with token refresh and retry if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 { respBody, _ := io.ReadAll(httpResp.Body) @@ -1162,6 +1202,11 @@ // Handle different event types switch eventType { + case "followupPromptEvent": + // Filter out followupPrompt events - these are UI suggestions, not content + log.Debugf("kiro: parseEventStream ignoring followupPrompt event") + continue + case "assistantResponseEvent": if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { if contentText, ok := assistantResp["content"].(string); ok { @@ -1570,6 +1615,11 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } switch eventType { + case "followupPromptEvent": + // Filter out followupPrompt events - these are UI suggestions, not content + log.Debugf("kiro: streamToChannel ignoring followupPrompt event") + continue + case "assistantResponseEvent": var contentDelta string var toolUses []map[string]interface{} @@ -1961,7 +2011,8 @@ func (e *KiroExecutor) buildClaudeFinalEvent() []byte { return []byte("event: message_stop\ndata: " + string(result)) } -// CountTokens is not supported for the Kiro provider. +// CountTokens is not supported for Kiro provider. +// Kiro/Amazon Q backend doesn't expose a token counting API. 
func (e *KiroExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for kiro"} } @@ -2988,18 +3039,33 @@ func (e *KiroExecutor) processToolUseEvent(event map[string]interface{}, current return toolUses, currentToolUse } -// deduplicateToolUses removes duplicate tool uses based on toolUseId. +// deduplicateToolUses removes duplicate tool uses based on toolUseId and content (name+arguments). +// This prevents both ID-based duplicates and content-based duplicates (same tool call with different IDs). func deduplicateToolUses(toolUses []kiroToolUse) []kiroToolUse { - seen := make(map[string]bool) + seenIDs := make(map[string]bool) + seenContent := make(map[string]bool) // Content-based deduplication (name + arguments) var unique []kiroToolUse for _, tu := range toolUses { - if !seen[tu.ToolUseID] { - seen[tu.ToolUseID] = true - unique = append(unique, tu) - } else { - log.Debugf("kiro: removing duplicate tool use: %s", tu.ToolUseID) + // Skip if we've already seen this ID + if seenIDs[tu.ToolUseID] { + log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name) + continue } + + // Build content key for content-based deduplication + inputJSON, _ := json.Marshal(tu.Input) + contentKey := tu.Name + ":" + string(inputJSON) + + // Skip if we've already seen this content (same name + arguments) + if seenContent[contentKey] { + log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID) + continue + } + + seenIDs[tu.ToolUseID] = true + seenContent[contentKey] = true + unique = append(unique, tu) } return unique From 204bba9dea1832752aa83f6111c1a0f24c1c7fdc Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Fri, 12 Dec 2025 09:27:30 +0800 Subject: [PATCH 27/38] refactor(kiro): update Kiro executor to use CodeWhisperer 
endpoint and improve tool calling support --- internal/registry/model_definitions.go | 45 ++++-- internal/runtime/executor/kiro_executor.go | 178 ++++++++++++++------- 2 files changed, 152 insertions(+), 71 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 59881c60..c80816f9 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -870,8 +870,9 @@ func GetGitHubCopilotModels() []*ModelInfo { // GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions func GetKiroModels() []*ModelInfo { return []*ModelInfo{ + // --- Base Models --- { - ID: "kiro-claude-opus-4.5", + ID: "kiro-claude-opus-4-5", Object: "model", Created: 1732752000, OwnedBy: "aws", @@ -882,7 +883,7 @@ func GetKiroModels() []*ModelInfo { MaxCompletionTokens: 64000, }, { - ID: "kiro-claude-sonnet-4.5", + ID: "kiro-claude-sonnet-4-5", Object: "model", Created: 1732752000, OwnedBy: "aws", @@ -904,7 +905,7 @@ func GetKiroModels() []*ModelInfo { MaxCompletionTokens: 64000, }, { - ID: "kiro-claude-haiku-4.5", + ID: "kiro-claude-haiku-4-5", Object: "model", Created: 1732752000, OwnedBy: "aws", @@ -914,21 +915,9 @@ func GetKiroModels() []*ModelInfo { ContextLength: 200000, MaxCompletionTokens: 64000, }, - // --- Chat Variant (No tool calling, for pure conversation) --- - { - ID: "kiro-claude-opus-4.5-chat", - Object: "model", - Created: 1732752000, - OwnedBy: "aws", - Type: "kiro", - DisplayName: "Kiro Claude Opus 4.5 (Chat)", - Description: "Claude Opus 4.5 for chat only (no tool calling)", - ContextLength: 200000, - MaxCompletionTokens: 64000, - }, // --- Agentic Variants (Optimized for coding agents with chunked writes) --- { - ID: "kiro-claude-opus-4.5-agentic", + ID: "kiro-claude-opus-4-5-agentic", Object: "model", Created: 1732752000, OwnedBy: "aws", @@ -939,7 +928,7 @@ func GetKiroModels() []*ModelInfo { MaxCompletionTokens: 64000, }, { - ID: "kiro-claude-sonnet-4.5-agentic", + ID: 
"kiro-claude-sonnet-4-5-agentic", Object: "model", Created: 1732752000, OwnedBy: "aws", @@ -949,6 +938,28 @@ func GetKiroModels() []*ModelInfo { ContextLength: 200000, MaxCompletionTokens: 64000, }, + { + ID: "kiro-claude-sonnet-4-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4 (Agentic)", + Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "kiro-claude-haiku-4-5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5 (Agentic)", + Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, } } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index bcdebe1f..2987e2d3 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -31,16 +31,18 @@ import ( ) const ( - // kiroEndpoint is the Amazon Q streaming endpoint for chat API (GenerateAssistantResponse). - // Note: This is different from the CodeWhisperer management endpoint (codewhisperer.us-east-1.amazonaws.com) - // used in aws_auth.go for GetUsageLimits, ListProfiles, etc. Both endpoints are correct - // for their respective API operations. - kiroEndpoint = "https://q.us-east-1.amazonaws.com" - kiroTargetChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" - kiroContentType = "application/x-amz-json-1.0" - kiroAcceptStream = "application/vnd.amazon.eventstream" + // kiroEndpoint is the CodeWhisperer streaming endpoint for chat API (GenerateAssistantResponse). + // Based on AIClient-2-API reference implementation. + // Note: Amazon Q uses a different endpoint (q.us-east-1.amazonaws.com) with different request format. 
+ kiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse" + kiroContentType = "application/json" + kiroAcceptStream = "application/json" kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." + // kiroUserAgent matches AIClient-2-API format for x-amz-user-agent header + kiroUserAgent = "aws-sdk-js/1.0.7 KiroIDE-0.1.25" + // kiroFullUserAgent is the complete user-agent header matching AIClient-2-API + kiroFullUserAgent = "aws-sdk-js/1.0.7 ua/2.1 os/linux lang/go api/codewhispererstreaming#1.0.7 m/E KiroIDE-0.1.25" // kiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. // AWS Kiro API has a 2-3 minute timeout for large file write operations. @@ -157,14 +159,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Check if this is a chat-only model variant (no tool calling) isChatOnly := strings.HasSuffix(req.Model, "-chat") - // Determine initial origin based on model type - // Opus models use AI_EDITOR (Kiro IDE quota), others start with CLI (Amazon Q quota) - var currentOrigin string - if strings.Contains(strings.ToLower(req.Model), "opus") { - currentOrigin = "AI_EDITOR" - } else { - currentOrigin = "CLI" - } + // Determine initial origin - always use AI_EDITOR to match AIClient-2-API behavior + // AIClient-2-API uses AI_EDITOR for all models, which is the Kiro IDE quota + // Note: CLI origin is for Amazon Q quota, but AIClient-2-API doesn't use it + currentOrigin := "AI_EDITOR" // Determine if profileArn should be included based on auth method // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) @@ -196,9 +194,13 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
} httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("x-amz-target", kiroTargetChat) httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) + httpReq.Header.Set("x-amz-user-agent", kiroUserAgent) + httpReq.Header.Set("User-Agent", kiroFullUserAgent) + httpReq.Header.Set("amz-sdk-request", "attempt=1; max=1") + httpReq.Header.Set("x-amzn-kiro-agent-mode", "vibe") + httpReq.Header.Set("amz-sdk-invocation-id", uuid.New().String()) var attrs map[string]string if auth != nil { @@ -409,14 +411,9 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut // Check if this is a chat-only model variant (no tool calling) isChatOnly := strings.HasSuffix(req.Model, "-chat") - // Determine initial origin based on model type - // Opus models use AI_EDITOR (Kiro IDE quota), others start with CLI (Amazon Q quota) - var currentOrigin string - if strings.Contains(strings.ToLower(req.Model), "opus") { - currentOrigin = "AI_EDITOR" - } else { - currentOrigin = "CLI" - } + // Determine initial origin - always use AI_EDITOR to match AIClient-2-API behavior + // AIClient-2-API uses AI_EDITOR for all models, which is the Kiro IDE quota + currentOrigin := "AI_EDITOR" // Determine if profileArn should be included based on auth method // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) @@ -446,9 +443,13 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("x-amz-target", kiroTargetChat) httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) + httpReq.Header.Set("x-amz-user-agent", kiroUserAgent) + httpReq.Header.Set("User-Agent", kiroFullUserAgent) + httpReq.Header.Set("amz-sdk-request", "attempt=1; max=1") + httpReq.Header.Set("x-amzn-kiro-agent-mode", "vibe") + 
httpReq.Header.Set("amz-sdk-invocation-id", uuid.New().String()) var attrs map[string]string if auth != nil { @@ -630,36 +631,81 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { // Agentic variants (-agentic suffix) map to the same backend model IDs. func (e *KiroExecutor) mapModelToKiro(model string) string { modelMap := map[string]string{ - "kiro-claude-opus-4.5": "claude-opus-4.5", - "kiro-claude-sonnet-4.5": "claude-sonnet-4.5", - "kiro-claude-sonnet-4": "claude-sonnet-4", - "kiro-claude-haiku-4.5": "claude-haiku-4.5", // Amazon Q format (amazonq- prefix) - same API as Kiro - "amazonq-auto": "auto", - "amazonq-claude-opus-4.5": "claude-opus-4.5", - "amazonq-claude-sonnet-4.5": "claude-sonnet-4.5", - "amazonq-claude-sonnet-4": "claude-sonnet-4", - "amazonq-claude-haiku-4.5": "claude-haiku-4.5", - // Native Kiro format (no prefix) - used by Kiro IDE directly - "claude-opus-4.5": "claude-opus-4.5", - "claude-sonnet-4.5": "claude-sonnet-4.5", - "claude-sonnet-4": "claude-sonnet-4", - "claude-haiku-4.5": "claude-haiku-4.5", - "auto": "auto", - // Chat variant (no tool calling support) - "kiro-claude-opus-4.5-chat": "claude-opus-4.5", + "amazonq-auto": "auto", + "amazonq-claude-opus-4-5": "claude-opus-4.5", + "amazonq-claude-sonnet-4-5": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4": "claude-sonnet-4", + "amazonq-claude-sonnet-4-20250514": "claude-sonnet-4", + "amazonq-claude-haiku-4-5": "claude-haiku-4.5", + // Kiro format (kiro- prefix) - valid model names that should be preserved + "kiro-claude-opus-4-5": "claude-opus-4.5", + "kiro-claude-sonnet-4-5": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "kiro-claude-sonnet-4": "claude-sonnet-4", + "kiro-claude-sonnet-4-20250514": "claude-sonnet-4", + "kiro-claude-haiku-4-5": "claude-haiku-4.5", + "kiro-auto": "auto", + // Native format (no prefix) - used by Kiro IDE directly + 
"claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4.5": "claude-opus-4.5", + "claude-haiku-4-5": "claude-haiku-4.5", + "claude-haiku-4.5": "claude-haiku-4.5", + "claude-sonnet-4-5": "claude-sonnet-4.5", + "claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-sonnet-4-20250514": "claude-sonnet-4", + "auto": "auto", // Agentic variants (same backend model IDs, but with special system prompt) - "kiro-claude-opus-4.5-agentic": "claude-opus-4.5", - "kiro-claude-sonnet-4.5-agentic": "claude-sonnet-4.5", - "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", - "kiro-claude-haiku-4.5-agentic": "claude-haiku-4.5", - "amazonq-claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + "claude-opus-4.5-agentic": "claude-opus-4.5", + "claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + "claude-sonnet-4-agentic": "claude-sonnet-4", + "claude-haiku-4.5-agentic": "claude-haiku-4.5", + "kiro-claude-opus-4-5-agentic": "claude-opus-4.5", + "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", + "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", } if kiroID, ok := modelMap[model]; ok { return kiroID } - log.Debugf("kiro: unknown model '%s', falling back to 'auto'", model) - return "auto" + + // Smart fallback: try to infer model type from name patterns + modelLower := strings.ToLower(model) + + // Check for Haiku variants + if strings.Contains(modelLower, "haiku") { + log.Debugf("kiro: unknown Haiku model '%s', mapping to claude-haiku-4.5", model) + return "claude-haiku-4.5" + } + + // Check for Sonnet variants + if strings.Contains(modelLower, "sonnet") { + // Check for specific version patterns + if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") { + log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model) + return "claude-3-7-sonnet-20250219" + } + if strings.Contains(modelLower, 
"4-5") || strings.Contains(modelLower, "4.5") { + log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model) + return "claude-sonnet-4.5" + } + // Default to Sonnet 4 + log.Debugf("kiro: unknown Sonnet model '%s', mapping to claude-sonnet-4", model) + return "claude-sonnet-4" + } + + // Check for Opus variants + if strings.Contains(modelLower, "opus") { + log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model) + return "claude-opus-4.5" + } + + // Final fallback to Sonnet 4.5 (most commonly used model) + log.Warnf("kiro: unknown model '%s', falling back to claude-sonnet-4.5", model) + return "claude-sonnet-4.5" } // Kiro API request structs - field order determines JSON key order @@ -673,7 +719,7 @@ type kiroConversationState struct { ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field ConversationID string `json:"conversationId"` CurrentMessage kiroCurrentMessage `json:"currentMessage"` - History []kiroHistoryMessage `json:"history,omitempty"` // Only include when non-empty + History []kiroHistoryMessage `json:"history,omitempty"` } type kiroCurrentMessage struct { @@ -750,6 +796,24 @@ type kiroToolUse struct { // isAgentic parameter enables chunked write optimization prompt for -agentic model variants. // isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). 
func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { + // Normalize origin value for Kiro API compatibility + // Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values + switch origin { + case "KIRO_CLI": + origin = "CLI" + case "KIRO_AI_EDITOR": + origin = "AI_EDITOR" + case "AMAZON_Q": + origin = "CLI" + case "KIRO_IDE": + origin = "AI_EDITOR" + // Add any other non-standard origin values that need normalization + default: + // Keep the original value if it's already standard + // Valid values: "CLI", "AI_EDITOR" + } + log.Debugf("kiro: normalized origin value: %s", origin) + messages := gjson.GetBytes(claudeBody, "messages") // For chat-only mode, don't include tools @@ -942,13 +1006,12 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, } // Build payload with correct field order (matches struct definition) - // Note: history is omitempty, so nil/empty slice won't be serialized payload := kiroPayload{ ConversationState: kiroConversationState{ ChatTriggerType: "MANUAL", // Required by Kiro API - must be first ConversationID: uuid.New().String(), CurrentMessage: currentMessage, - History: history, // Will be omitted if empty due to omitempty tag + History: history, // Now always included (non-nil slice) }, ProfileArn: profileArn, } @@ -958,6 +1021,7 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, log.Debugf("kiro: failed to marshal payload: %v", err) return nil } + return result } @@ -2025,7 +2089,13 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c e.refreshMu.Lock() defer e.refreshMu.Unlock() - log.Debugf("kiro executor: refresh called for auth %s", auth.ID) + var authID string + if auth != nil { + authID = auth.ID + } else { + authID = "" + } + log.Debugf("kiro executor: refresh called for auth %s", authID) if auth == nil { return nil, fmt.Errorf("kiro executor: auth is 
nil") } From 84920cb6709a475f7054d067e5f9d9588c057d62 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Fri, 12 Dec 2025 13:43:36 +0800 Subject: [PATCH 28/38] feat(kiro): add multi-endpoint fallback & thinking mode support --- PR_DOCUMENTATION.md | 49 ++ internal/api/modules/amp/proxy.go | 10 +- internal/config/config.go | 9 + internal/runtime/executor/kiro_executor.go | 732 +++++++++++++++--- .../chat-completions/kiro_openai_request.go | 29 + .../chat-completions/kiro_openai_response.go | 31 + internal/watcher/watcher.go | 17 + 7 files changed, 787 insertions(+), 90 deletions(-) create mode 100644 PR_DOCUMENTATION.md diff --git a/PR_DOCUMENTATION.md b/PR_DOCUMENTATION.md new file mode 100644 index 00000000..6b830af6 --- /dev/null +++ b/PR_DOCUMENTATION.md @@ -0,0 +1,49 @@ +# PR Title / 拉取请求标题 + +`feat(kiro): Add Thinking Mode support & enhance reliability with multi-quota failover` +`feat(kiro): 支持思考模型 (Thinking Mode) 并通过多配额故障转移增强稳定性` + +--- + +# PR Description / 拉取请求描述 + +## 📝 Summary / 摘要 + +This PR introduces significant upgrades to the Kiro (AWS CodeWhisperer/Amazon Q) module. It adds native support for **Thinking/Reasoning models** (similar to OpenAI o1/Claude 3.7), implements a robust **Multi-Endpoint Failover** system to handle rate limits (429), and optimizes configuration flexibility. + +本次 PR 对 Kiro (AWS CodeWhisperer/Amazon Q) 模块进行了重大升级。它增加了对 **思考/推理模型 (Thinking/Reasoning models)** 的原生支持(类似 OpenAI o1/Claude 3.7),实现了一套健壮的 **多端点故障转移 (Multi-Endpoint Failover)** 系统以应对速率限制 (429),并优化了配置灵活性。 + +## ✨ Key Changes / 主要变更 + +### 1. 🧠 Thinking Mode Support / 思考模式支持 +- **OpenAI Compatibility**: Automatically maps OpenAI's `reasoning_effort` parameter (low/medium/high) to Claude's `budget_tokens` (4k/16k/32k). + - **OpenAI 兼容性**:自动将 OpenAI 的 `reasoning_effort` 参数(low/medium/high)映射为 Claude 的 `budget_tokens`(4k/16k/32k)。 +- **Stream Parsing**: Implemented advanced stream parsing logic to detect and extract content within `...` tags, even across chunk boundaries. 
+ - **流式解析**:实现了高级流式解析逻辑,能够检测并提取 `...` 标签内的内容,即使标签跨越了数据块边界。 +- **Protocol Translation**: Converts Kiro's internal thinking content into OpenAI-compatible `reasoning_content` fields (for non-stream) or `thinking_delta` events (for stream). + - **协议转换**:将 Kiro 内部的思考内容转换为兼容 OpenAI 的 `reasoning_content` 字段(非流式)或 `thinking_delta` 事件(流式)。 + +### 2. 🛡️ Robustness & Failover / 稳健性与故障转移 +- **Dual Quota System**: Explicitly defined `kiroEndpointConfig` to distinguish between **IDE (CodeWhisperer)** and **CLI (Amazon Q)** quotas. + - **双配额系统**:显式定义了 `kiroEndpointConfig` 结构,明确区分 **IDE (CodeWhisperer)** 和 **CLI (Amazon Q)** 的配额来源。 +- **Auto Failover**: Implemented automatic failover logic. If one endpoint returns `429 Too Many Requests`, the request seamlessly retries on the next available endpoint/quota. + - **自动故障转移**:实现了自动故障转移逻辑。如果一个端点返回 `429 Too Many Requests`,请求将无缝在下一个可用端点/配额上重试。 +- **Strict Protocol Compliance**: Enforced strict matching of `Origin` and `X-Amz-Target` headers for each endpoint to prevent `403 Forbidden` errors due to protocol mismatches. + - **严格协议合规**:强制每个端点严格匹配 `Origin` 和 `X-Amz-Target` 头信息,防止因协议不匹配导致的 `403 Forbidden` 错误。 + +### 3. ⚙️ Configuration & Models / 配置与模型 +- **New Config Options**: Added `KiroPreferredEndpoint` (global) and `PreferredEndpoint` (per-key) settings to allow users to prioritize specific quotas (e.g., "ide" or "cli"). + - **新配置项**:添加了 `KiroPreferredEndpoint`(全局)和 `PreferredEndpoint`(单 Key)设置,允许用户优先选择特定的配额(如 "ide" 或 "cli")。 +- **Model Registry**: Normalized model IDs (replaced dots with hyphens) and added `-agentic` variants optimized for large code generation tasks. + - **模型注册表**:规范化了模型 ID(将点号替换为连字符),并添加了针对大型代码生成任务优化的 `-agentic` 变体。 + +### 4. 🔧 Fixes / 修复 +- **AMP Proxy**: Downgraded client-side context cancellation logs from `Error` to `Debug` to reduce log noise. + - **AMP 代理**:将客户端上下文取消的日志级别从 `Error` 降级为 `Debug`,减少日志噪音。 + +## ⚠️ Impact / 影响 + +- **Authentication**: **No changes** to the login/OAuth process. 
Existing tokens work as is. +- **认证**:登录/OAuth 流程 **无变更**。现有 Token 可直接使用。 +- **Compatibility**: Fully backward compatible. The new failover logic is transparent to the user. +- **兼容性**:完全向后兼容。新的故障转移逻辑对用户是透明的。 \ No newline at end of file diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 33f32c28..6a6b1b54 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -3,6 +3,8 @@ package amp import ( "bytes" "compress/gzip" + "context" + "errors" "fmt" "io" "net/http" @@ -148,7 +150,13 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi // Error handler for proxy failures proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) + // Check if this is a client-side cancellation (normal behavior) + // Don't log as error for context canceled - it's usually client closing connection + if errors.Is(err, context.Canceled) { + log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path) + } else { + log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) + } rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(http.StatusBadGateway) _, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) diff --git a/internal/config/config.go b/internal/config/config.go index f9da2c29..86b79ad2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -64,6 +64,10 @@ type Config struct { // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations. KiroKey []KiroKey `yaml:"kiro" json:"kiro"` + // KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers. + // Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q). 
+ KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"` + // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` @@ -278,6 +282,10 @@ type KiroKey struct { // AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat". // Leave empty to let API use defaults. Different values may inject different system prompts. AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"` + + // PreferredEndpoint sets the preferred Kiro API endpoint/quota. + // Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota). + PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"` } // OpenAICompatibility represents the configuration for OpenAI API compatibility @@ -504,6 +512,7 @@ func (cfg *Config) SanitizeKiroKeys() { entry.ProfileArn = strings.TrimSpace(entry.ProfileArn) entry.Region = strings.TrimSpace(entry.Region) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint) } } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 2987e2d3..bff3fb57 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -31,18 +31,23 @@ import ( ) const ( - // kiroEndpoint is the CodeWhisperer streaming endpoint for chat API (GenerateAssistantResponse). - // Based on AIClient-2-API reference implementation. - // Note: Amazon Q uses a different endpoint (q.us-east-1.amazonaws.com) with different request format. 
- kiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse" - kiroContentType = "application/json" - kiroAcceptStream = "application/json" + // Kiro API common constants + kiroContentType = "application/x-amz-json-1.0" + kiroAcceptStream = "*/*" kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." - // kiroUserAgent matches AIClient-2-API format for x-amz-user-agent header - kiroUserAgent = "aws-sdk-js/1.0.7 KiroIDE-0.1.25" - // kiroFullUserAgent is the complete user-agent header matching AIClient-2-API - kiroFullUserAgent = "aws-sdk-js/1.0.7 ua/2.1 os/linux lang/go api/codewhispererstreaming#1.0.7 m/E KiroIDE-0.1.25" + // kiroUserAgent matches amq2api format for User-Agent header + kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" + // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api + kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" + + // Thinking mode support - based on amq2api implementation + // These tags wrap reasoning content in the response stream + thinkingStartTag = "" + thinkingEndTag = "" + // thinkingHint is injected into the request to enable interleaved thinking mode + // This tells the model to use thinking tags and sets the max thinking length + thinkingHint = "interleaved16000" // kiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. // AWS Kiro API has a 2-3 minute timeout for large file write operations. @@ -97,6 +102,106 @@ You MUST follow these rules for ALL file operations. Violation causes server tim REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.` ) +// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. 
+// This solves the "triple mismatch" problem where different endpoints require matching +// Origin and X-Amz-Target header values. +// +// Based on reference implementations: +// - amq2api-main: Uses Amazon Q endpoint with CLI origin and AmazonQDeveloperStreamingService target +// - AIClient-2-API: Uses CodeWhisperer endpoint with AI_EDITOR origin and AmazonCodeWhispererStreamingService target +type kiroEndpointConfig struct { + URL string // Endpoint URL + Origin string // Request Origin: "CLI" for Amazon Q quota, "AI_EDITOR" for Kiro IDE quota + AmzTarget string // X-Amz-Target header value + Name string // Endpoint name for logging +} + +// kiroEndpointConfigs defines the available Kiro API endpoints with their compatible configurations. +// The order determines fallback priority: primary endpoint first, then fallbacks. +// +// CRITICAL: Each endpoint MUST use its compatible Origin and AmzTarget values: +// - CodeWhisperer endpoint (codewhisperer.us-east-1.amazonaws.com): Uses AI_EDITOR origin and AmazonCodeWhispererStreamingService target +// - Amazon Q endpoint (q.us-east-1.amazonaws.com): Uses CLI origin and AmazonQDeveloperStreamingService target +// +// Mismatched combinations will result in 403 Forbidden errors. +// +// NOTE: CodeWhisperer is set as the default endpoint because: +// 1. Most tokens come from Kiro IDE / VSCode extensions (AWS Builder ID auth) +// 2. These tokens use AI_EDITOR origin which is only compatible with CodeWhisperer endpoint +// 3. Amazon Q endpoint requires CLI origin which is for Amazon Q CLI tokens +// This matches the AIClient-2-API-main project's configuration. 
+var kiroEndpointConfigs = []kiroEndpointConfig{ + { + URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse", + Origin: "AI_EDITOR", + AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse", + Name: "CodeWhisperer", + }, + { + URL: "https://q.us-east-1.amazonaws.com/", + Origin: "CLI", + AmzTarget: "AmazonQDeveloperStreamingService.SendMessage", + Name: "AmazonQ", + }, +} + +// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order. +// Supports reordering based on "preferred_endpoint" in auth metadata/attributes. +func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { + if auth == nil { + return kiroEndpointConfigs + } + + // Check for preference + var preference string + if auth.Metadata != nil { + if p, ok := auth.Metadata["preferred_endpoint"].(string); ok { + preference = p + } + } + // Check attributes as fallback (e.g. from HTTP headers) + if preference == "" && auth.Attributes != nil { + preference = auth.Attributes["preferred_endpoint"] + } + + if preference == "" { + return kiroEndpointConfigs + } + + preference = strings.ToLower(strings.TrimSpace(preference)) + + // Create new slice to avoid modifying global state + var sorted []kiroEndpointConfig + var remaining []kiroEndpointConfig + + for _, cfg := range kiroEndpointConfigs { + name := strings.ToLower(cfg.Name) + // Check for matches + // CodeWhisperer aliases: codewhisperer, ide + // AmazonQ aliases: amazonq, q, cli + isMatch := false + if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" { + isMatch = true + } else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" { + isMatch = true + } + + if isMatch { + sorted = append(sorted, cfg) + } else { + remaining = append(remaining, cfg) + } + } + + // If preference didn't match anything, return default + if len(sorted) == 0 { + return kiroEndpointConfigs + } + + // Combine: 
preferred first, then others + return append(sorted, remaining...) +} + // KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. type KiroExecutor struct { cfg *config.Config @@ -181,13 +286,29 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } // executeWithRetry performs the actual HTTP request with automatic retry on auth errors. -// Supports automatic fallback from CLI (Amazon Q) quota to AI_EDITOR (Kiro IDE) quota on 429. +// Supports automatic fallback between endpoints with different quotas: +// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota +// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota +// Also supports multi-endpoint fallback similar to Antigravity implementation. func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) { var resp cliproxyexecutor.Response - maxRetries := 2 // Allow retries for token refresh + origin fallback + maxRetries := 2 // Allow retries for token refresh + endpoint fallback + endpointConfigs := getKiroEndpointConfigs(auth) + + for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { + endpointConfig := endpointConfigs[endpointIdx] + url := endpointConfig.URL + // Use this endpoint's compatible Origin (critical for avoiding 403 errors) + currentOrigin = endpointConfig.Origin + + // Rebuild payload with the correct origin for this endpoint + // Each endpoint requires its matching Origin value in the request body + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)", + endpointIdx+1, len(endpointConfigs), url, 
endpointConfig.Name, currentOrigin) for attempt := 0; attempt <= maxRetries; attempt++ { - url := kiroEndpoint httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) if err != nil { return resp, err @@ -196,11 +317,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. httpReq.Header.Set("Content-Type", kiroContentType) httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) - httpReq.Header.Set("x-amz-user-agent", kiroUserAgent) - httpReq.Header.Set("User-Agent", kiroFullUserAgent) - httpReq.Header.Set("amz-sdk-request", "attempt=1; max=1") - httpReq.Header.Set("x-amzn-kiro-agent-mode", "vibe") - httpReq.Header.Set("amz-sdk-invocation-id", uuid.New().String()) + // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) var attrs map[string]string if auth != nil { @@ -234,27 +356,17 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
} recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - // Handle 429 errors (quota exhausted) with origin fallback + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints if httpResp.StatusCode == 429 { respBody, _ := io.ReadAll(httpResp.Body) _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) - // If currently using CLI quota and it's exhausted, switch to AI_EDITOR (Kiro IDE) quota - if currentOrigin == "CLI" { - log.Warnf("kiro: Amazon Q (CLI) quota exhausted (429), switching to Kiro (AI_EDITOR) fallback") - currentOrigin = "AI_EDITOR" - - // Rebuild payload with new origin - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - - // Retry with new origin - continue - } - - // Already on AI_EDITOR or other origin, return the error - log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + log.Warnf("kiro: %s endpoint quota exhausted (429), will try next endpoint", endpointConfig.Name) + + // Break inner retry loop to try next endpoint (which has different quota) + break } // Handle 5xx server errors with exponential backoff retry @@ -277,14 +389,15 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - // Handle 401/403 errors with token refresh and retry - if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 { + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { respBody, _ := io.ReadAll(httpResp.Body) _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) if attempt < maxRetries { - log.Warnf("kiro: received %d error, attempting token refresh and retry (attempt %d/%d)", httpResp.StatusCode, attempt+1, maxRetries+1) + log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) refreshedAuth, refreshErr := e.Refresh(ctx, auth) if refreshErr != nil { @@ -302,7 +415,66 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } } - log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: received 402 (monthly limit). 
Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired + // Do NOT switch endpoints for 403 errors + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Log the 403 error details for debugging + log.Warnf("kiro: received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + respBodyStr := string(respBody) + + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + log.Errorf("kiro: account is suspended, cannot proceed") + return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} + } + + // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed for 403, retrying request") + continue + } + } + + // For non-token 403 or after 
max retries, return error immediately
+			// Do NOT switch endpoints for 403 errors
+			log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)")
 			return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
 		}
@@ -362,9 +534,14 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
 		out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
 		resp = cliproxyexecutor.Response{Payload: []byte(out)}
 		return resp, nil
+		}
+		// Inner retry loop exhausted for this endpoint (reached via the 429
+		// quota-exhausted break above); fall through so the outer loop can
+		// try the next endpoint, which draws on a separate quota pool.
 	}
-	return resp, fmt.Errorf("kiro: max retries exceeded")
+	// All endpoints exhausted
+	return resp, fmt.Errorf("kiro: all endpoints exhausted")
 }

 // ExecuteStream handles streaming requests to Kiro API.
@@ -431,12 +608,28 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
 }

 // executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
-// Supports automatic fallback from CLI (Amazon Q) quota to AI_EDITOR (Kiro IDE) quota on 429.
+// Supports automatic fallback between endpoints with different quotas:
+// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
+// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
+// Also supports multi-endpoint fallback similar to Antigravity implementation.
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) { - maxRetries := 2 // Allow retries for token refresh + origin fallback + maxRetries := 2 // Allow retries for token refresh + endpoint fallback + endpointConfigs := getKiroEndpointConfigs(auth) + + for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { + endpointConfig := endpointConfigs[endpointIdx] + url := endpointConfig.URL + // Use this endpoint's compatible Origin (critical for avoiding 403 errors) + currentOrigin = endpointConfig.Origin + + // Rebuild payload with the correct origin for this endpoint + // Each endpoint requires its matching Origin value in the request body + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + + log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", + endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) for attempt := 0; attempt <= maxRetries; attempt++ { - url := kiroEndpoint httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) if err != nil { return nil, err @@ -445,11 +638,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox httpReq.Header.Set("Content-Type", kiroContentType) httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) - httpReq.Header.Set("x-amz-user-agent", kiroUserAgent) - httpReq.Header.Set("User-Agent", kiroFullUserAgent) - httpReq.Header.Set("amz-sdk-request", "attempt=1; max=1") - httpReq.Header.Set("x-amzn-kiro-agent-mode", "vibe") - httpReq.Header.Set("amz-sdk-invocation-id", 
uuid.New().String()) + // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) var attrs map[string]string if auth != nil { @@ -483,27 +677,17 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - // Handle 429 errors (quota exhausted) with origin fallback + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints if httpResp.StatusCode == 429 { respBody, _ := io.ReadAll(httpResp.Body) _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) - // If currently using CLI quota and it's exhausted, switch to AI_EDITOR (Kiro IDE) quota - if currentOrigin == "CLI" { - log.Warnf("kiro: stream Amazon Q (CLI) quota exhausted (429), switching to Kiro (AI_EDITOR) fallback") - currentOrigin = "AI_EDITOR" - - // Rebuild payload with new origin - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - - // Retry with new origin - continue - } - - // Already on AI_EDITOR or other origin, return the error - log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(respBody)) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint", endpointConfig.Name) + + // Break inner retry loop to try next endpoint (which has different quota) + break } // Handle 5xx server errors with exponential backoff retry @@ -526,14 +710,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth 
*cliprox return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - // Handle 401/403 errors with token refresh and retry - if httpResp.StatusCode == 401 || httpResp.StatusCode == 403 { + // Handle 400 errors - Credential/Validation issues + // Do NOT switch endpoints - return error immediately + if httpResp.StatusCode == 400 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // 400 errors indicate request validation issues - return immediately without retry + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { respBody, _ := io.ReadAll(httpResp.Body) _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) if attempt < maxRetries { - log.Warnf("kiro: stream received %d error, attempting token refresh and retry (attempt %d/%d)", httpResp.StatusCode, attempt+1, maxRetries+1) + log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) refreshedAuth, refreshErr := e.Refresh(ctx, auth) if refreshErr != nil { @@ -551,7 +749,66 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } } - log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(respBody)) + log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: 
stream received 402 (monthly limit). Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired + // Do NOT switch endpoints for 403 errors + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Log the 403 error details for debugging + log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) + + respBodyStr := string(respBody) + + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + log.Errorf("kiro: account is suspended, cannot proceed") + return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} + } + + // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed for 403, retrying stream request") + continue + } + } + + // For non-token 403 or after 
max retries, return error immediately
+			// Do NOT switch endpoints for 403 errors
+			log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)")
 			return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
 		}
@@ -585,9 +842,14 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
 			}(httpResp)

 			return out, nil
+		}
+		// Inner retry loop exhausted for this endpoint (reached via the 429
+		// quota-exhausted break above); fall through so the outer loop can
+		// try the next endpoint, which draws on a separate quota pool.
 	}
-	return nil, fmt.Errorf("kiro: max retries exceeded for stream")
+	// All endpoints exhausted
+	return nil, fmt.Errorf("kiro: stream all endpoints exhausted")
 }

@@ -795,6 +1057,7 @@ type kiroToolUse struct {
 // origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
 // isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
 // isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
+// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint.
func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { // Normalize origin value for Kiro API compatibility // Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values @@ -840,6 +1103,39 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, systemPrompt = systemField.String() } + // Check for thinking parameter in Claude API request + // Claude API format: {"thinking": {"type": "enabled", "budget_tokens": 16000}} + // When thinking is enabled, inject dynamic thinkingHint based on budget_tokens + // This allows reasoning_effort (low/medium/high) to control actual thinking length + thinkingEnabled := false + var budgetTokens int64 = 16000 // Default value (same as OpenAI reasoning_effort "medium") + thinkingField := gjson.GetBytes(claudeBody, "thinking") + if thinkingField.Exists() { + // Check if thinking.type is "enabled" + thinkingType := thinkingField.Get("type").String() + if thinkingType == "enabled" { + thinkingEnabled = true + // Read budget_tokens if specified - this value comes from: + // - Claude API: thinking.budget_tokens directly + // - OpenAI API: reasoning_effort -> budget_tokens (low:4000, medium:16000, high:32000) + if bt := thinkingField.Get("budget_tokens"); bt.Exists() && bt.Int() > 0 { + budgetTokens = bt.Int() + } + log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) + } + } + + // Inject timestamp context for better temporal awareness + // Based on amq2api implementation - helps model understand current time context + timestamp := time.Now().Format("2006-01-02 15:04:05 MST") + timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp) + if systemPrompt != "" { + systemPrompt = timestampContext + "\n\n" + systemPrompt + } else { + systemPrompt = timestampContext + } + log.Debugf("kiro: injected timestamp context: %s", timestamp) + // Inject agentic optimization prompt 
for -agentic model variants // This prevents AWS Kiro API timeouts during large file write operations if isAgentic { @@ -849,6 +1145,20 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, systemPrompt += kiroAgenticSystemPrompt } + // Inject thinking hint when thinking mode is enabled + // This tells the model to use tags in its response + // DYNAMICALLY set max_thinking_length based on budget_tokens from request + // This respects the reasoning_effort setting: low(4000), medium(16000), high(32000) + if thinkingEnabled { + if systemPrompt != "" { + systemPrompt += "\n" + } + // Build dynamic thinking hint with the actual budget_tokens value + dynamicThinkingHint := fmt.Sprintf("interleaved%d", budgetTokens) + systemPrompt += dynamicThinkingHint + log.Debugf("kiro: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens) + } + // Convert Claude tools to Kiro format var kiroTools []kiroToolWrapper if tools.IsArray() { @@ -859,13 +1169,15 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // Truncate long descriptions (Kiro API limit is in bytes) // Truncate at valid UTF-8 boundary to avoid breaking multi-byte chars + // Add truncation notice to help model understand the description is incomplete if len(description) > kiroMaxToolDescLen { // Find a valid UTF-8 boundary before the limit - truncLen := kiroMaxToolDescLen + // Reserve space for truncation notice (about 30 bytes) + truncLen := kiroMaxToolDescLen - 30 for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { truncLen-- } - description = description[:truncLen] + "..." + description = description[:truncLen] + "... 
(description truncated)" } kiroTools = append(kiroTools, kiroToolWrapper{ @@ -1505,6 +1817,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // IMPORTANT: This must persist across all TranslateStream calls var translatorParam any + // Thinking mode state tracking - based on amq2api implementation + // Tracks whether we're inside a block and handles partial tags + inThinkBlock := false + pendingStartTagChars := 0 // Number of chars that might be start of + pendingEndTagChars := 0 // Number of chars that might be start of + isThinkingBlockOpen := false // Track if thinking content block is open + thinkingBlockIndex := -1 // Index of the thinking content block + // Pre-calculate input tokens from request if possible if enc, err := tokenizerForModel(model); err == nil { // Try OpenAI format first, then fall back to raw byte count estimation @@ -1715,7 +2035,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Handle text content with duplicate detection + // Handle text content with duplicate detection and thinking mode support if contentDelta != "" { // Check for duplicate content - skip if identical to last content event // Based on AIClient-2-API implementation for Kiro @@ -1728,24 +2048,218 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out outputLen += len(contentDelta) // Accumulate content for streaming token calculation accumulatedContent.WriteString(contentDelta) - // Start text content block if needed - if !isTextBlockOpen { - contentBlockIndex++ - isTextBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } + + // 
Process content with thinking tag detection - based on amq2api implementation + // This handles and tags that may span across chunks + remaining := contentDelta + + // If we have pending start tag chars from previous chunk, prepend them + if pendingStartTagChars > 0 { + remaining = thinkingStartTag[:pendingStartTagChars] + remaining + pendingStartTagChars = 0 + } + + // If we have pending end tag chars from previous chunk, prepend them + if pendingEndTagChars > 0 { + remaining = thinkingEndTag[:pendingEndTagChars] + remaining + pendingEndTagChars = 0 } - claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + for len(remaining) > 0 { + if inThinkBlock { + // Inside thinking block - look for end tag + endIdx := strings.Index(remaining, thinkingEndTag) + if endIdx >= 0 { + // Found end tag - emit any content before end tag, then close block + thinkContent := remaining[:endIdx] + if thinkContent != "" { + // TRUE STREAMING: Emit thinking content immediately + // Start thinking block if not open + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Send thinking delta immediately + thinkingEvent := e.buildClaudeThinkingDeltaEvent(thinkContent, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, 
sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Note: Partial tag handling is done via pendingEndTagChars + // When the next chunk arrives, the partial tag will be reconstructed + + // Close thinking block + if isThinkingBlockOpen { + blockStop := e.buildClaudeContentBlockStopEvent(thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isThinkingBlockOpen = false + } + + inThinkBlock = false + remaining = remaining[endIdx+len(thinkingEndTag):] + log.Debugf("kiro: exited thinking block") + } else { + // No end tag found - TRUE STREAMING: emit content immediately + // Only save potential partial tag length for next iteration + pendingEnd := pendingTagSuffix(remaining, thinkingEndTag) + + // Calculate content to emit immediately (excluding potential partial tag) + var contentToEmit string + if pendingEnd > 0 { + contentToEmit = remaining[:len(remaining)-pendingEnd] + // Save partial tag length for next iteration (will be reconstructed from thinkingEndTag) + pendingEndTagChars = pendingEnd + } else { + contentToEmit = remaining + } + + // TRUE STREAMING: Emit thinking content immediately + if contentToEmit != "" { + // Start thinking block if not open + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) 
+ for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Send thinking delta immediately - TRUE STREAMING! + thinkingEvent := e.buildClaudeThinkingDeltaEvent(contentToEmit, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + remaining = "" + } + } else { + // Outside thinking block - look for start tag + startIdx := strings.Index(remaining, thinkingStartTag) + if startIdx >= 0 { + // Found start tag - emit text before it and switch to thinking mode + textBefore := remaining[:startIdx] + if textBefore != "" { + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(textBefore, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Close text block before starting thinking block + if isTextBlockOpen { + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, 
claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + inThinkBlock = true + remaining = remaining[startIdx+len(thinkingStartTag):] + log.Debugf("kiro: entered thinking block") + } else { + // No start tag found - check for partial start tag at buffer end + pendingStart := pendingTagSuffix(remaining, thinkingStartTag) + if pendingStart > 0 { + // Emit text except potential partial tag + textToEmit := remaining[:len(remaining)-pendingStart] + if textToEmit != "" { + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + pendingStartTagChars = pendingStart + remaining = "" + } else { + // No partial tag - emit all as text + if remaining != "" { + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { 
+ if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = "" + } + } } } } @@ -1981,14 +2495,20 @@ func (e *KiroExecutor) buildClaudeMessageStartEvent(model string, inputTokens in func (e *KiroExecutor) buildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { var contentBlock map[string]interface{} - if blockType == "tool_use" { + switch blockType { + case "tool_use": contentBlock = map[string]interface{}{ "type": "tool_use", "id": toolUseID, "name": toolName, "input": map[string]interface{}{}, } - } else { + case "thinking": + contentBlock = map[string]interface{}{ + "type": "thinking", + "thinking": "", + } + default: contentBlock = map[string]interface{}{ "type": "text", "text": "", @@ -2075,6 +2595,40 @@ func (e *KiroExecutor) buildClaudeFinalEvent() []byte { return []byte("event: message_stop\ndata: " + string(result)) } +// buildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility. +// This is used when streaming thinking content wrapped in tags. +func (e *KiroExecutor) buildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "thinking_delta", + "thinking": thinkingDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// pendingTagSuffix detects if the buffer ends with a partial prefix of the given tag. 
+// Returns the length of the partial match (0 if no match). +// Based on amq2api implementation for handling cross-chunk tag boundaries. +func pendingTagSuffix(buffer, tag string) int { + if buffer == "" || tag == "" { + return 0 + } + maxLen := len(buffer) + if maxLen > len(tag)-1 { + maxLen = len(tag) - 1 + } + for length := maxLen; length > 0; length-- { + if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] { + return length + } + } + return 0 +} + // CountTokens is not supported for Kiro provider. // Kiro/Amazon Q backend doesn't expose a token counting API. func (e *KiroExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go index 3d339505..d1094c1c 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go @@ -10,9 +10,18 @@ import ( "github.com/tidwall/sjson" ) +// reasoningEffortToBudget maps OpenAI reasoning_effort values to Claude thinking budget_tokens. +// OpenAI uses "low", "medium", "high" while Claude uses numeric budget_tokens. +var reasoningEffortToBudget = map[string]int{ + "low": 4000, + "medium": 16000, + "high": 32000, +} + // ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format. // Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format. // Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result. +// Supports reasoning/thinking: OpenAI reasoning_effort -> Claude thinking parameter. 
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { rawJSON := bytes.Clone(inputRawJSON) root := gjson.ParseBytes(rawJSON) @@ -38,6 +47,26 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo out, _ = sjson.Set(out, "top_p", v.Float()) } + // Handle OpenAI reasoning_effort parameter -> Claude thinking parameter + // OpenAI format: {"reasoning_effort": "low"|"medium"|"high"} + // Claude format: {"thinking": {"type": "enabled", "budget_tokens": N}} + if v := root.Get("reasoning_effort"); v.Exists() { + effort := v.String() + if budget, ok := reasoningEffortToBudget[effort]; ok { + thinking := map[string]interface{}{ + "type": "enabled", + "budget_tokens": budget, + } + out, _ = sjson.Set(out, "thinking", thinking) + } + } + + // Also support direct thinking parameter passthrough (for Claude API compatibility) + // Claude format: {"thinking": {"type": "enabled", "budget_tokens": N}} + if v := root.Get("thinking"); v.Exists() && v.IsObject() { + out, _ = sjson.Set(out, "thinking", v.Value()) + } + // Convert OpenAI tools to Claude tools format if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { claudeTools := make([]interface{}, 0) diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go index d56c94ac..2fab2a4d 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go @@ -134,6 +134,28 @@ func convertClaudeEventToOpenAI(jsonStr string, model string) []string { result, _ := json.Marshal(response) results = append(results, string(result)) } + } else if deltaType == "thinking_delta" { + // Thinking/reasoning content delta - convert to OpenAI reasoning_content format + thinkingDelta := root.Get("delta.thinking").String() + if thinkingDelta != "" { + response := 
map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "reasoning_content": thinkingDelta, + }, + "finish_reason": nil, + }, + }, + } + result, _ := json.Marshal(response) + results = append(results, string(result)) + } } else if deltaType == "input_json_delta" { // Tool input delta (streaming arguments) partialJSON := root.Get("delta.partial_json").String() @@ -298,6 +320,7 @@ func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, ori root := gjson.ParseBytes(rawResponse) var content string + var reasoningContent string var toolCalls []map[string]interface{} contentArray := root.Get("content") @@ -306,6 +329,9 @@ func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, ori itemType := item.Get("type").String() if itemType == "text" { content += item.Get("text").String() + } else if itemType == "thinking" { + // Extract thinking/reasoning content + reasoningContent += item.Get("thinking").String() } else if itemType == "tool_use" { // Convert Claude tool_use to OpenAI tool_calls format inputJSON := item.Get("input").String() @@ -339,6 +365,11 @@ func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, ori "content": content, } + // Add reasoning_content if present (OpenAI reasoning format) + if reasoningContent != "" { + message["reasoning_content"] = reasoningContent + } + // Add tool_calls if present if len(toolCalls) > 0 { message["tool_calls"] = toolCalls diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 36276de9..428ab814 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1317,6 +1317,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.AgentTaskType != "" { attrs["agent_task_type"] = kk.AgentTaskType } + if kk.PreferredEndpoint != "" 
{ + attrs["preferred_endpoint"] = kk.PreferredEndpoint + } else if cfg.KiroPreferredEndpoint != "" { + // Apply global default if not overridden by specific key + attrs["preferred_endpoint"] = cfg.KiroPreferredEndpoint + } if refreshToken != "" { attrs["refresh_token"] = refreshToken } @@ -1532,6 +1538,17 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) } } + + // Apply global preferred endpoint setting if not present in metadata + if cfg.KiroPreferredEndpoint != "" { + // Check if already set in metadata (which takes precedence in executor) + if _, hasMeta := metadata["preferred_endpoint"]; !hasMeta { + if a.Attributes == nil { + a.Attributes = make(map[string]string) + } + a.Attributes["preferred_endpoint"] = cfg.KiroPreferredEndpoint + } + } } applyAuthExcludedModelsMeta(a, cfg, nil, "oauth") From 54d4fd7f84c1c187aea4d3de013b5cdd807fc36d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 12 Dec 2025 16:10:15 +0800 Subject: [PATCH 29/38] remove PR_DOCUMENTATION.md --- PR_DOCUMENTATION.md | 49 --------------------------------------------- 1 file changed, 49 deletions(-) delete mode 100644 PR_DOCUMENTATION.md diff --git a/PR_DOCUMENTATION.md b/PR_DOCUMENTATION.md deleted file mode 100644 index 6b830af6..00000000 --- a/PR_DOCUMENTATION.md +++ /dev/null @@ -1,49 +0,0 @@ -# PR Title / 拉取请求标题 - -`feat(kiro): Add Thinking Mode support & enhance reliability with multi-quota failover` -`feat(kiro): 支持思考模型 (Thinking Mode) 并通过多配额故障转移增强稳定性` - ---- - -# PR Description / 拉取请求描述 - -## 📝 Summary / 摘要 - -This PR introduces significant upgrades to the Kiro (AWS CodeWhisperer/Amazon Q) module. It adds native support for **Thinking/Reasoning models** (similar to OpenAI o1/Claude 3.7), implements a robust **Multi-Endpoint Failover** system to handle rate limits (429), and optimizes configuration flexibility. 
- -本次 PR 对 Kiro (AWS CodeWhisperer/Amazon Q) 模块进行了重大升级。它增加了对 **思考/推理模型 (Thinking/Reasoning models)** 的原生支持(类似 OpenAI o1/Claude 3.7),实现了一套健壮的 **多端点故障转移 (Multi-Endpoint Failover)** 系统以应对速率限制 (429),并优化了配置灵活性。 - -## ✨ Key Changes / 主要变更 - -### 1. 🧠 Thinking Mode Support / 思考模式支持 -- **OpenAI Compatibility**: Automatically maps OpenAI's `reasoning_effort` parameter (low/medium/high) to Claude's `budget_tokens` (4k/16k/32k). - - **OpenAI 兼容性**:自动将 OpenAI 的 `reasoning_effort` 参数(low/medium/high)映射为 Claude 的 `budget_tokens`(4k/16k/32k)。 -- **Stream Parsing**: Implemented advanced stream parsing logic to detect and extract content within `...` tags, even across chunk boundaries. - - **流式解析**:实现了高级流式解析逻辑,能够检测并提取 `...` 标签内的内容,即使标签跨越了数据块边界。 -- **Protocol Translation**: Converts Kiro's internal thinking content into OpenAI-compatible `reasoning_content` fields (for non-stream) or `thinking_delta` events (for stream). - - **协议转换**:将 Kiro 内部的思考内容转换为兼容 OpenAI 的 `reasoning_content` 字段(非流式)或 `thinking_delta` 事件(流式)。 - -### 2. 🛡️ Robustness & Failover / 稳健性与故障转移 -- **Dual Quota System**: Explicitly defined `kiroEndpointConfig` to distinguish between **IDE (CodeWhisperer)** and **CLI (Amazon Q)** quotas. - - **双配额系统**:显式定义了 `kiroEndpointConfig` 结构,明确区分 **IDE (CodeWhisperer)** 和 **CLI (Amazon Q)** 的配额来源。 -- **Auto Failover**: Implemented automatic failover logic. If one endpoint returns `429 Too Many Requests`, the request seamlessly retries on the next available endpoint/quota. - - **自动故障转移**:实现了自动故障转移逻辑。如果一个端点返回 `429 Too Many Requests`,请求将无缝在下一个可用端点/配额上重试。 -- **Strict Protocol Compliance**: Enforced strict matching of `Origin` and `X-Amz-Target` headers for each endpoint to prevent `403 Forbidden` errors due to protocol mismatches. - - **严格协议合规**:强制每个端点严格匹配 `Origin` 和 `X-Amz-Target` 头信息,防止因协议不匹配导致的 `403 Forbidden` 错误。 - -### 3. 
⚙️ Configuration & Models / 配置与模型 -- **New Config Options**: Added `KiroPreferredEndpoint` (global) and `PreferredEndpoint` (per-key) settings to allow users to prioritize specific quotas (e.g., "ide" or "cli"). - - **新配置项**:添加了 `KiroPreferredEndpoint`(全局)和 `PreferredEndpoint`(单 Key)设置,允许用户优先选择特定的配额(如 "ide" 或 "cli")。 -- **Model Registry**: Normalized model IDs (replaced dots with hyphens) and added `-agentic` variants optimized for large code generation tasks. - - **模型注册表**:规范化了模型 ID(将点号替换为连字符),并添加了针对大型代码生成任务优化的 `-agentic` 变体。 - -### 4. 🔧 Fixes / 修复 -- **AMP Proxy**: Downgraded client-side context cancellation logs from `Error` to `Debug` to reduce log noise. - - **AMP 代理**:将客户端上下文取消的日志级别从 `Error` 降级为 `Debug`,减少日志噪音。 - -## ⚠️ Impact / 影响 - -- **Authentication**: **No changes** to the login/OAuth process. Existing tokens work as is. -- **认证**:登录/OAuth 流程 **无变更**。现有 Token 可直接使用。 -- **Compatibility**: Fully backward compatible. The new failover logic is transparent to the user. -- **兼容性**:完全向后兼容。新的故障转移逻辑对用户是透明的。 \ No newline at end of file From db80b20bc233aa88d8d88ef6f63950696caf6807 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sat, 13 Dec 2025 03:51:14 +0800 Subject: [PATCH 30/38] feat(kiro): enhance thinking support and fix truncation issues - **Thinking Support**: - Enabled thinking support for all Kiro Claude models, including Haiku 4.5 and agentic variants. - Updated `model_definitions.go` with thinking configuration (Min: 1024, Max: 32000, ZeroAllowed: true). - Fixed `extended_thinking` field names in `model_registry.go` (from `min_budget`/`max_budget` to `min`/`max`) to comply with Claude API specs, enabling thinking control in clients like Claude Code. - **Kiro Executor Fixes**: - Fixed `budget_tokens` handling: explicitly disable thinking when budget is 0 or negative. - Removed aggressive duplicate content filtering logic that caused truncation/data loss. 
- Enhanced thinking tag parsing with `extractThinkingFromContent` to correctly handle interleaved thinking/text blocks. - Added EOF handling to flush pending thinking tag characters, preventing data loss at stream end. - **Performance**: - Optimized Claude stream handler (v6.2) with reduced buffer size (4KB) and faster flush interval (50ms) to minimize latency and prevent timeouts. --- internal/registry/model_definitions.go | 8 + internal/registry/model_registry.go | 16 +- internal/runtime/executor/kiro_executor.go | 197 +++++++++++++++++++-- sdk/api/handlers/claude/code_handlers.go | 26 +-- 4 files changed, 219 insertions(+), 28 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 4df0cf67..c6282759 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -895,6 +895,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Opus 4.5 via Kiro (2.2x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4-5", @@ -906,6 +907,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4", @@ -917,6 +919,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4 via Kiro (1.3x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-haiku-4-5", @@ -928,6 +931,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Haiku 4.5 via Kiro (0.4x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, 
// --- Agentic Variants (Optimized for coding agents with chunked writes) --- { @@ -940,6 +944,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4-5-agentic", @@ -951,6 +956,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4-agentic", @@ -962,6 +968,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-haiku-4-5-agentic", @@ -973,6 +980,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, } } diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index f3517bde..8f575df4 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -748,7 +748,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) } return result - case "claude": + case "claude", "kiro", "antigravity": + // Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client result := map[string]any{ "id": model.ID, "object": "model", @@ -763,6 +764,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) if model.DisplayName != "" { 
result["display_name"] = model.DisplayName } + // Add thinking support for Claude Code client + // Claude Code checks for "thinking" field (simple boolean) to enable tab toggle + // Also add "extended_thinking" for detailed budget info + if model.Thinking != nil { + result["thinking"] = true + result["extended_thinking"] = map[string]any{ + "supported": true, + "min": model.Thinking.Min, + "max": model.Thinking.Max, + "zero_allowed": model.Thinking.ZeroAllowed, + "dynamic_allowed": model.Thinking.DynamicAllowed, + } + } return result case "gemini": diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index bff3fb57..60148829 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -1118,10 +1118,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // Read budget_tokens if specified - this value comes from: // - Claude API: thinking.budget_tokens directly // - OpenAI API: reasoning_effort -> budget_tokens (low:4000, medium:16000, high:32000) - if bt := thinkingField.Get("budget_tokens"); bt.Exists() && bt.Int() > 0 { + if bt := thinkingField.Get("budget_tokens"); bt.Exists() { budgetTokens = bt.Int() + // If budget_tokens <= 0, disable thinking explicitly + // This allows users to disable thinking by setting budget_tokens to 0 + if budgetTokens <= 0 { + thinkingEnabled = false + log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0") + } + } + if thinkingEnabled { + log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) } - log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) } } @@ -1737,15 +1745,23 @@ func getString(m map[string]interface{}, key string) string { // buildClaudeResponse constructs a Claude-compatible response. // Supports tool_use blocks when tools are present in the response. 
+// Supports thinking blocks - parses tags and converts to Claude thinking content blocks. func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte { var contentBlocks []map[string]interface{} - // Add text content if present + // Extract thinking blocks and text from content + // This handles ... tags from Kiro's response if content != "" { - contentBlocks = append(contentBlocks, map[string]interface{}{ - "type": "text", - "text": content, - }) + blocks := e.extractThinkingFromContent(content) + contentBlocks = append(contentBlocks, blocks...) + + // DIAGNOSTIC: Log if thinking blocks were extracted + for _, block := range blocks { + if block["type"] == "thinking" { + thinkingContent := block["thinking"].(string) + log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent)) + } + } } // Add tool_use blocks @@ -1788,6 +1804,101 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs return result } +// extractThinkingFromContent parses content to extract thinking blocks and text. +// Returns a list of content blocks in the order they appear in the content. +// Handles interleaved thinking and text blocks correctly. +// Based on the streaming implementation's thinking tag handling. 
+func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]interface{} { + var blocks []map[string]interface{} + + if content == "" { + return blocks + } + + // Check if content contains thinking tags at all + if !strings.Contains(content, thinkingStartTag) { + // No thinking tags, return as plain text + return []map[string]interface{}{ + { + "type": "text", + "text": content, + }, + } + } + + log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content)) + + remaining := content + + for len(remaining) > 0 { + // Look for tag + startIdx := strings.Index(remaining, thinkingStartTag) + + if startIdx == -1 { + // No more thinking tags, add remaining as text + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": remaining, + }) + } + break + } + + // Add text before thinking tag (if any meaningful content) + if startIdx > 0 { + textBefore := remaining[:startIdx] + if strings.TrimSpace(textBefore) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": textBefore, + }) + } + } + + // Move past the opening tag + remaining = remaining[startIdx+len(thinkingStartTag):] + + // Find closing tag + endIdx := strings.Index(remaining, thinkingEndTag) + + if endIdx == -1 { + // No closing tag found, treat rest as thinking content (incomplete response) + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": remaining, + }) + log.Warnf("kiro: extractThinkingFromContent - missing closing tag") + } + break + } + + // Extract thinking content between tags + thinkContent := remaining[:endIdx] + if strings.TrimSpace(thinkContent) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": thinkContent, + }) + log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent)) + } + + // Move 
past the closing tag + remaining = remaining[endIdx+len(thinkingEndTag):] + } + + // If no blocks were created (all whitespace), return empty text block + if len(blocks) == 0 { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + return blocks +} + // NOTE: Tool uses are now extracted from API response, not parsed from text @@ -1804,9 +1915,10 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out processedIDs := make(map[string]bool) var currentToolUse *toolUseState - // Duplicate content detection - tracks last content event to filter duplicates - // Based on AIClient-2-API implementation for Kiro - var lastContentEvent string + // NOTE: Duplicate content filtering removed - it was causing legitimate repeated + // content (like consecutive newlines) to be incorrectly filtered out. + // The previous implementation compared lastContentEvent == contentDelta which + // is too aggressive for streaming scenarios. // Streaming token calculation - accumulate content for real-time token counting // Based on AIClient-2-API implementation @@ -1905,6 +2017,56 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out hasToolUses = true currentToolUse = nil } + + // Flush any pending tag characters at EOF + // These are partial tag prefixes that were held back waiting for more data + // Since no more data is coming, output them as regular text + var pendingText string + if pendingStartTagChars > 0 { + pendingText = thinkingStartTag[:pendingStartTagChars] + log.Debugf("kiro: flushing pending start tag chars at EOF: %q", pendingText) + pendingStartTagChars = 0 + } + if pendingEndTagChars > 0 { + pendingText += thinkingEndTag[:pendingEndTagChars] + log.Debugf("kiro: flushing pending end tag chars at EOF: %q", pendingText) + pendingEndTagChars = 0 + } + + // Output pending text if any + if pendingText != "" { + // If we're in a thinking block, output as thinking content + if inThinkBlock 
&& isThinkingBlockOpen { + thinkingEvent := e.buildClaudeThinkingDeltaEvent(pendingText, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } else { + // Output as regular text + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(pendingText, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } break } if err != nil { @@ -2035,15 +2197,16 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Handle text content with duplicate detection and thinking mode support + // Handle text content with thinking mode support if contentDelta != "" { - // Check for duplicate content - skip if identical to last content event - // Based on AIClient-2-API implementation for Kiro - if contentDelta == lastContentEvent { - log.Debugf("kiro: skipping duplicate content event (len: %d)", len(contentDelta)) - continue + // DIAGNOSTIC: Check for thinking tags in response + if strings.Contains(contentDelta, "") || strings.Contains(contentDelta, "") { + log.Infof("kiro: DIAGNOSTIC - Found thinking tag 
in response (len: %d)", len(contentDelta)) } - lastContentEvent = contentDelta + + // NOTE: Duplicate content filtering was removed because it incorrectly + // filtered out legitimate repeated content (like consecutive newlines "\n\n"). + // Streaming naturally can have identical chunks that are valid content. outputLen += len(contentDelta) // Accumulate content for streaming token calculation diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 8a57a0cc..be2028e1 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -219,12 +219,12 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ } func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - // v6.1: Intelligent Buffered Streamer strategy - // Enhanced buffering with larger buffer size (16KB) and longer flush interval (120ms). - // Smart flush only when buffer is sufficiently filled (≥50%), dramatically reducing - // flush frequency from ~12.5Hz to ~5-8Hz while maintaining low latency. - writer := bufio.NewWriterSize(c.Writer, 16*1024) // 4KB → 16KB - ticker := time.NewTicker(120 * time.Millisecond) // 80ms → 120ms + // v6.2: Immediate flush strategy for SSE streams + // SSE requires immediate data delivery to prevent client timeouts. + // Previous buffering strategy (16KB buffer, 8KB threshold) caused delays + // because SSE events are typically small (< 1KB), leading to client retries. + writer := bufio.NewWriterSize(c.Writer, 4*1024) // 4KB buffer (smaller for faster flush) + ticker := time.NewTicker(50 * time.Millisecond) // 50ms interval for responsive streaming defer ticker.Stop() var chunkIdx int @@ -238,10 +238,9 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http. 
return case <-ticker.C: - // Smart flush: only flush when buffer has sufficient data (≥50% full) - // This reduces flush frequency while ensuring data flows naturally - buffered := writer.Buffered() - if buffered >= 8*1024 { // At least 8KB (50% of 16KB buffer) + // Flush any buffered data on timer to ensure responsiveness + // For SSE, we flush whenever there's any data to prevent client timeouts + if writer.Buffered() > 0 { if err := writer.Flush(); err != nil { // Error flushing, cancel and return cancel(err) @@ -254,6 +253,7 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http. if !ok { // Stream ended, flush remaining data _ = writer.Flush() + flusher.Flush() cancel(nil) return } @@ -263,6 +263,12 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http. // The handler just needs to forward it without reassembly. if len(chunk) > 0 { _, _ = writer.Write(chunk) + // Immediately flush for first few chunks to establish connection quickly + // This prevents client timeout/retry on slow backends like Kiro + if chunkIdx < 3 { + _ = writer.Flush() + flusher.Flush() + } } chunkIdx++ From 58866b21cb7c3910b342f5c75d6f4b75538c84e8 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sat, 13 Dec 2025 10:19:53 +0800 Subject: [PATCH 31/38] feat: optimize connection pooling and improve Kiro executor reliability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 中文说明 ### 连接池优化 - 为 AMP 代理、SOCKS5 代理和 HTTP 代理配置优化的连接池参数 - MaxIdleConnsPerHost 从默认的 2 增加到 20,支持更多并发用户 - MaxConnsPerHost 设为 0(无限制),避免连接瓶颈 - 添加 IdleConnTimeout (90s) 和其他超时配置 ### Kiro 执行器增强 - 添加 Event Stream 消息解析的边界保护,防止越界访问 - 实现实时使用量估算(每 5000 字符或 15 秒发送 ping 事件) - 正确从上游事件中提取并传递 stop_reason - 改进输入 token 计算,优先使用 Claude 格式解析 - 添加 max_tokens 截断警告日志 ### Token 计算改进 - 添加 tokenizer 缓存(sync.Map)避免重复创建 - 为 Claude/Kiro/AmazonQ 模型添加 1.1 调整因子 - 新增 countClaudeChatTokens 函数支持 Claude API 格式 - 支持图像 token 估算(基于尺寸计算) ### 认证刷新优化 - RefreshLead 从 30 
分钟改为 5 分钟,与 Antigravity 保持一致 - 修复 NextRefreshAfter 设置,防止频繁刷新检查 - refreshFailureBackoff 从 5 分钟改为 1 分钟,加快失败恢复 --- ## English Description ### Connection Pool Optimization - Configure optimized connection pool parameters for AMP proxy, SOCKS5 proxy, and HTTP proxy - Increase MaxIdleConnsPerHost from default 2 to 20 to support more concurrent users - Set MaxConnsPerHost to 0 (unlimited) to avoid connection bottlenecks - Add IdleConnTimeout (90s) and other timeout configurations ### Kiro Executor Enhancements - Add boundary protection for Event Stream message parsing to prevent out-of-bounds access - Implement real-time usage estimation (send ping events every 5000 chars or 15 seconds) - Correctly extract and pass stop_reason from upstream events - Improve input token calculation, prioritize Claude format parsing - Add max_tokens truncation warning logs ### Token Calculation Improvements - Add tokenizer cache (sync.Map) to avoid repeated creation - Add 1.1 adjustment factor for Claude/Kiro/AmazonQ models - Add countClaudeChatTokens function to support Claude API format - Support image token estimation (calculated based on dimensions) ### Authentication Refresh Optimization - Change RefreshLead from 30 minutes to 5 minutes, consistent with Antigravity - Fix NextRefreshAfter setting to prevent frequent refresh checks - Change refreshFailureBackoff from 5 minutes to 1 minute for faster failure recovery --- internal/api/modules/amp/proxy.go | 50 +- internal/runtime/executor/kiro_executor.go | 610 ++++++++++++++++----- internal/runtime/executor/proxy_helpers.go | 16 +- internal/runtime/executor/token_helpers.go | 287 +++++++++- internal/util/proxy.go | 17 +- sdk/auth/kiro.go | 18 +- sdk/cliproxy/auth/manager.go | 6 +- 7 files changed, 840 insertions(+), 164 deletions(-) diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 6a6b1b54..5a3f2081 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -7,11 +7,13 
@@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" "strconv" "strings" + "time" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" @@ -36,6 +38,22 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi } proxy := httputil.NewSingleHostReverseProxy(parsed) + + // Configure custom Transport with optimized connection pooling for high concurrency + proxy.Transport = &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users + MaxConnsPerHost: 0, // No limit on max concurrent connections per host + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 60 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + originalDirector := proxy.Director // Modify outgoing requests to inject API key and fix routing @@ -64,7 +82,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi // Modify incoming responses to handle gzip without Content-Encoding // This addresses the same issue as inline handler gzip handling, but at the proxy level proxy.ModifyResponse = func(resp *http.Response) error { - // Only process successful responses + // Log upstream error responses for diagnostics (502, 503, etc.) 
+ // These are NOT proxy connection errors - the upstream responded with an error status + if resp.StatusCode >= 500 { + log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) + } else if resp.StatusCode >= 400 { + log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) + } + + // Only process successful responses for gzip decompression if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil } @@ -148,15 +174,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi return nil } - // Error handler for proxy failures + // Error handler for proxy failures with detailed error classification for diagnostics proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - // Check if this is a client-side cancellation (normal behavior) + // Classify the error type for better diagnostics + var errType string + if errors.Is(err, context.DeadlineExceeded) { + errType = "timeout" + } else if errors.Is(err, context.Canceled) { + errType = "canceled" + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + errType = "dial_timeout" + } else if _, ok := err.(net.Error); ok { + errType = "network_error" + } else { + errType = "connection_error" + } + // Don't log as error for context canceled - it's usually client closing connection if errors.Is(err, context.Canceled) { - log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path) + log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path) } else { - log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) + log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err) } + rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(http.StatusBadGateway) _, _ = 
rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 60148829..cbc5443b 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -36,6 +36,16 @@ const ( kiroAcceptStream = "*/*" kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." + + // Event Stream frame size constants for boundary protection + // AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes) + // Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4) + minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc) + maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB + + // Event Stream error type constants + ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable + ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed // kiroUserAgent matches amq2api format for User-Agent header kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api @@ -102,6 +112,13 @@ You MUST follow these rules for ALL file operations. Violation causes server tim REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.` ) +// Real-time usage estimation configuration +// These control how often usage updates are sent during streaming +var ( + usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters + usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first +) + // kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. 
// This solves the "triple mismatch" problem where different endpoints require matching // Origin and X-Amz-Target header values. @@ -495,7 +512,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } }() - content, toolUses, usageInfo, err := e.parseEventStream(httpResp.Body) + content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return resp, err @@ -503,14 +520,14 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. // Fallback for usage if missing from upstream if usageInfo.TotalTokens == 0 { - if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if enc, encErr := getTokenizer(req.Model); encErr == nil { if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { usageInfo.InputTokens = inp } } if len(content) > 0 { // Use tiktoken for more accurate output token calculation - if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if enc, encErr := getTokenizer(req.Model); encErr == nil { if tokenCount, countErr := enc.Count(content); countErr == nil { usageInfo.OutputTokens = int64(tokenCount) } @@ -530,7 +547,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
reporter.publish(ctx, usageInfo) // Build response in Claude format for Kiro translator - kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo) + // stopReason is extracted from upstream response by parseEventStream + kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) resp = cliproxyexecutor.Response{Payload: []byte(out)} return resp, nil @@ -970,11 +988,40 @@ func (e *KiroExecutor) mapModelToKiro(model string) string { return "claude-sonnet-4.5" } +// EventStreamError represents an Event Stream processing error +type EventStreamError struct { + Type string // "fatal", "malformed" + Message string + Cause error +} + +func (e *EventStreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("event stream %s: %s", e.Type, e.Message) +} + +// eventStreamMessage represents a parsed AWS Event Stream message +type eventStreamMessage struct { + EventType string // Event type from headers (e.g., "assistantResponseEvent") + Payload []byte // JSON payload of the message +} + // Kiro API request structs - field order determines JSON key order type kiroPayload struct { ConversationState kiroConversationState `json:"conversationState"` ProfileArn string `json:"profileArn,omitempty"` + InferenceConfig *kiroInferenceConfig `json:"inferenceConfig,omitempty"` +} + +// kiroInferenceConfig contains inference parameters for the Kiro API. +// NOTE: This is an experimental addition - Kiro/Amazon Q API may not support these parameters. +// If the API ignores or rejects these fields, response length is controlled internally by the model. 
+type kiroInferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` // Maximum output tokens (may be ignored by API) + Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (may be ignored by API) } type kiroConversationState struct { @@ -1058,7 +1105,25 @@ type kiroToolUse struct { // isAgentic parameter enables chunked write optimization prompt for -agentic model variants. // isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). // Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. +// +// max_tokens support: Kiro/Amazon Q API may not officially support max_tokens parameter. +// We attempt to pass it via inferenceConfig.maxTokens, but the API may ignore it. +// Response truncation can be detected via stop_reason == "max_tokens" in the response. func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { + // Extract max_tokens for potential use in inferenceConfig + var maxTokens int64 + if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() { + maxTokens = mt.Int() + } + + // Extract temperature if specified + var temperature float64 + var hasTemperature bool + if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() { + temperature = temp.Float() + hasTemperature = true + } + // Normalize origin value for Kiro API compatibility // Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values switch origin { @@ -1325,6 +1390,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, }} } + // Build inferenceConfig if we have any inference parameters + var inferenceConfig *kiroInferenceConfig + if maxTokens > 0 || hasTemperature { + inferenceConfig = &kiroInferenceConfig{} + if maxTokens > 0 { + inferenceConfig.MaxTokens = int(maxTokens) + } + if hasTemperature { + inferenceConfig.Temperature = temperature + } + } + // Build 
payload with correct field order (matches struct definition) payload := kiroPayload{ ConversationState: kiroConversationState{ @@ -1333,7 +1410,8 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, CurrentMessage: currentMessage, History: history, // Now always included (non-nil slice) }, - ProfileArn: profileArn, + ProfileArn: profileArn, + InferenceConfig: inferenceConfig, } result, err := json.Marshal(payload) @@ -1493,12 +1571,14 @@ func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssista // NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults // parseEventStream parses AWS Event Stream binary format. -// Extracts text content and tool uses from the response. +// Extracts text content, tool uses, and stop_reason from the response. // Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. -func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, error) { +// Returns: content, toolUses, usageInfo, stopReason, error +func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, string, error) { var content strings.Builder var toolUses []kiroToolUse var usageInfo usage.Detail + var stopReason string // Extracted from upstream response reader := bufio.NewReader(body) // Tool use state tracking for input buffering and deduplication @@ -1506,59 +1586,28 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, var currentToolUse *toolUseState for { - prelude := make([]byte, 8) - _, err := io.ReadFull(reader, prelude) - if err == io.EOF { + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + log.Errorf("kiro: parseEventStream error: %v", eventErr) + return content.String(), toolUses, usageInfo, stopReason, eventErr + } + if msg == nil { + // Normal end of stream (EOF) break } - if err != nil { - return content.String(), toolUses, usageInfo, 
fmt.Errorf("failed to read prelude: %w", err) - } - totalLen := binary.BigEndian.Uint32(prelude[0:4]) - if totalLen < 8 { - return content.String(), toolUses, usageInfo, fmt.Errorf("invalid message length: %d", totalLen) - } - if totalLen > kiroMaxMessageSize { - return content.String(), toolUses, usageInfo, fmt.Errorf("message too large: %d bytes", totalLen) - } - headersLen := binary.BigEndian.Uint32(prelude[4:8]) - - remaining := make([]byte, totalLen-8) - _, err = io.ReadFull(reader, remaining) - if err != nil { - return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err) - } - - // Validate headersLen to prevent slice out of bounds - if headersLen+4 > uint32(len(remaining)) { - log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { continue } - // Extract event type from headers - eventType := e.extractEventType(remaining[:headersLen+4]) - - payloadStart := 4 + headersLen - payloadEnd := uint32(len(remaining)) - 4 - if payloadStart >= payloadEnd { - continue - } - - payload := remaining[payloadStart:payloadEnd] - var event map[string]interface{} if err := json.Unmarshal(payload, &event); err != nil { log.Debugf("kiro: skipping malformed event: %v", err) continue } - // DIAGNOSTIC: Log all received event types for debugging - log.Debugf("kiro: parseEventStream received event type: %s", eventType) - if log.IsLevelEnabled(log.TraceLevel) { - log.Tracef("kiro: parseEventStream event payload: %s", string(payload)) - } - // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) // These can appear as top-level fields or nested within the event if errType, hasErrType := event["_type"].(string); hasErrType { @@ -1568,7 +1617,7 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, errMsg = msg } log.Errorf("kiro: received AWS error in event stream: 
type=%s, message=%s", errType, errMsg) - return "", nil, usageInfo, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) + return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) } if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { // Generic error event @@ -1581,7 +1630,18 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, } } log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) - return "", nil, usageInfo, fmt.Errorf("kiro API error: %s", errMsg) + return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg) + } + + // Extract stop_reason from various event formats + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := getString(event, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason) } // Handle different event types @@ -1596,6 +1656,15 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, if contentText, ok := assistantResp["content"].(string); ok { content.WriteString(contentText) } + // Extract stop_reason from assistantResponseEvent + if sr := getString(assistantResp, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason) + } + if sr := getString(assistantResp, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason) + } // Extract tool uses from response if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok { for _, tuRaw := range toolUsesRaw { @@ -1661,6 +1730,17 @@ func (e *KiroExecutor) 
parseEventStream(body io.Reader) (string, []kiroToolUse, if outputTokens, ok := event["outputTokens"].(float64); ok { usageInfo.OutputTokens = int64(outputTokens) } + + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := getString(event, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason) + } } // Also check nested supplementaryWebLinksEvent @@ -1682,10 +1762,166 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, // Deduplicate all tool uses toolUses = deduplicateToolUses(toolUses) - return cleanedContent, toolUses, usageInfo, nil + // Apply fallback logic for stop_reason if not provided by upstream + // Priority: upstream stopReason > tool_use detection > end_turn default + if stopReason == "" { + if len(toolUses) > 0 { + stopReason = "tool_use" + log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses)) + } else { + stopReason = "end_turn" + log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit") + } + + return cleanedContent, toolUses, usageInfo, stopReason, nil +} + +// readEventStreamMessage reads and validates a single AWS Event Stream message. +// Returns the parsed message or a structured error for different failure modes. +// This function implements boundary protection and detailed error classification. 
+// +// AWS Event Stream binary format: +// - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4) +// - Headers (variable): header entries +// - Payload (variable): JSON data +// - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped) +func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) { + // Read prelude (first 12 bytes: total_len + headers_len + prelude_crc) + prelude := make([]byte, 12) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + return nil, nil // Normal end of stream + } + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read prelude", + Cause: err, + } + } + + totalLength := binary.BigEndian.Uint32(prelude[0:4]) + headersLength := binary.BigEndian.Uint32(prelude[4:8]) + // Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements) + + // Boundary check: minimum frame size + if totalLength < minEventStreamFrameSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize), + } + } + + // Boundary check: maximum message size + if totalLength > maxEventStreamMsgSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize), + } + } + + // Boundary check: headers length within message bounds + // Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4) + // So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc) + if headersLength > totalLength-16 { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength), + } + } + + // Read the rest of the message (total - 12 bytes already 
read) + remaining := make([]byte, totalLength-12) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read message body", + Cause: err, + } + } + + // Extract event type from headers + // Headers start at beginning of 'remaining', length is headersLength + var eventType string + if headersLength > 0 && headersLength <= uint32(len(remaining)) { + eventType = e.extractEventTypeFromBytes(remaining[:headersLength]) + } + + // Calculate payload boundaries + // Payload starts after headers, ends before message_crc (last 4 bytes) + payloadStart := headersLength + payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end + + // Validate payload boundaries + if payloadStart >= payloadEnd { + // No payload, return empty message + return &eventStreamMessage{ + EventType: eventType, + Payload: nil, + }, nil + } + + payload := remaining[payloadStart:payloadEnd] + + return &eventStreamMessage{ + EventType: eventType, + Payload: payload, + }, nil +} + +// extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix) +func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { + offset := 0 + for offset < len(headers) { + if offset >= len(headers) { + break + } + nameLen := int(headers[offset]) + offset++ + if offset+nameLen > len(headers) { + break + } + name := string(headers[offset : offset+nameLen]) + offset += nameLen + + if offset >= len(headers) { + break + } + valueType := headers[offset] + offset++ + + if valueType == 7 { // String type + if offset+2 > len(headers) { + break + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + break + } + value := string(headers[offset : offset+valueLen]) + offset += valueLen + + if name == ":event-type" { + return value + } + } else { + // Skip other types + break + } + } + return "" } // extractEventType extracts 
the event type from AWS Event Stream headers +// Note: This is the legacy version that expects headerBytes to include prelude CRC prefix func (e *KiroExecutor) extractEventType(headerBytes []byte) string { // Skip prelude CRC (4 bytes) if len(headerBytes) < 4 { @@ -1746,7 +1982,8 @@ func getString(m map[string]interface{}, key string) string { // buildClaudeResponse constructs a Claude-compatible response. // Supports tool_use blocks when tools are present in the response. // Supports thinking blocks - parses tags and converts to Claude thinking content blocks. -func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte { +// stopReason is passed from upstream; fallback logic applied if empty. +func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { var contentBlocks []map[string]interface{} // Extract thinking blocks and text from content @@ -1782,10 +2019,18 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs }) } - // Determine stop reason - stopReason := "end_turn" - if len(toolUses) > 0 { - stopReason = "tool_use" + // Use upstream stopReason; apply fallback logic if not provided + if stopReason == "" { + stopReason = "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason) + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)") } response := map[string]interface{}{ @@ -1906,10 +2151,12 @@ func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]i // Supports tool calling - emits tool_use content blocks when tools are used. // Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. 
// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). +// Extracts stop_reason from upstream events when available. func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers var totalUsage usage.Detail - var hasToolUses bool // Track if any tool uses were emitted + var hasToolUses bool // Track if any tool uses were emitted + var upstreamStopReason string // Track stop_reason from upstream events // Tool use state tracking for input buffering and deduplication processedIDs := make(map[string]bool) @@ -1925,6 +2172,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out var accumulatedContent strings.Builder accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations + // Real-time usage estimation state + // These track when to send periodic usage updates during streaming + var lastUsageUpdateLen int // Last accumulated content length when usage was sent + var lastUsageUpdateTime = time.Now() // Last time usage update was sent + var lastReportedOutputTokens int64 // Last reported output token count + // Translator param for maintaining tool call state across streaming events // IMPORTANT: This must persist across all TranslateStream calls var translatorParam any @@ -1932,24 +2185,37 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Thinking mode state tracking - based on amq2api implementation // Tracks whether we're inside a block and handles partial tags inThinkBlock := false - pendingStartTagChars := 0 // Number of chars that might be start of - pendingEndTagChars := 0 // Number of chars that might be start of - isThinkingBlockOpen := false // Track if thinking content block 
is open - thinkingBlockIndex := -1 // Index of the thinking content block + pendingStartTagChars := 0 // Number of chars that might be start of + pendingEndTagChars := 0 // Number of chars that might be start of + isThinkingBlockOpen := false // Track if thinking content block is open + thinkingBlockIndex := -1 // Index of the thinking content block // Pre-calculate input tokens from request if possible - if enc, err := tokenizerForModel(model); err == nil { - // Try OpenAI format first, then fall back to raw byte count estimation - if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { - totalUsage.InputTokens = inp + // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback + if enc, err := getTokenizer(model); err == nil { + var inputTokens int64 + var countMethod string + + // Try Claude format first (Kiro uses Claude API format) + if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 { + inputTokens = inp + countMethod = "claude" + } else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + // Fallback to OpenAI format (for OpenAI-compatible requests) + inputTokens = inp + countMethod = "openai" } else { - // Fallback: estimate from raw request size (roughly 4 chars per token) - totalUsage.InputTokens = int64(len(originalReq) / 4) - if totalUsage.InputTokens == 0 && len(originalReq) > 0 { - totalUsage.InputTokens = 1 + // Final fallback: estimate from raw request size (roughly 4 chars per token) + inputTokens = int64(len(claudeBody) / 4) + if inputTokens == 0 && len(claudeBody) > 0 { + inputTokens = 1 } + countMethod = "estimate" } - log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) + + totalUsage.InputTokens = inputTokens + log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", + 
totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) } contentBlockIndex := -1 @@ -1969,9 +2235,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out default: } - prelude := make([]byte, 8) - _, err := io.ReadFull(reader, prelude) - if err == io.EOF { + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + // Log the error + log.Errorf("kiro: streamToChannel error: %v", eventErr) + + // Send error to channel for client notification + out <- cliproxyexecutor.StreamChunk{Err: eventErr} + return + } + if msg == nil { + // Normal end of stream (EOF) // Flush any incomplete tool use before ending stream if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] { log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) @@ -2069,44 +2343,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } break } - if err != nil { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read prelude: %w", err)} - return - } - totalLen := binary.BigEndian.Uint32(prelude[0:4]) - if totalLen < 8 { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("invalid message length: %d", totalLen)} - return - } - if totalLen > kiroMaxMessageSize { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("message too large: %d bytes", totalLen)} - return - } - headersLen := binary.BigEndian.Uint32(prelude[4:8]) - - remaining := make([]byte, totalLen-8) - _, err = io.ReadFull(reader, remaining) - if err != nil { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read message: %w", err)} - return - } - - // Validate headersLen to prevent slice out of bounds - if headersLen+4 > uint32(len(remaining)) { - log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { continue } - - eventType := 
e.extractEventType(remaining[:headersLen+4]) - - payloadStart := 4 + headersLen - payloadEnd := uint32(len(remaining)) - 4 - if payloadStart >= payloadEnd { - continue - } - - payload := remaining[payloadStart:payloadEnd] appendAPIResponseChunk(ctx, e.cfg, payload) var event map[string]interface{} @@ -2115,12 +2357,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out continue } - // DIAGNOSTIC: Log all received event types for debugging - log.Debugf("kiro: streamToChannel received event type: %s", eventType) - if log.IsLevelEnabled(log.TraceLevel) { - log.Tracef("kiro: streamToChannel event payload: %s", string(payload)) - } - // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) // These can appear as top-level fields or nested within the event if errType, hasErrType := event["_type"].(string); hasErrType { @@ -2148,6 +2384,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out return } + // Extract stop_reason from various event formats (streaming) + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := getString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason) + } + // Send message_start on first event if !messageStartSent { msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) @@ -2166,6 +2413,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out log.Debugf("kiro: streamToChannel ignoring followupPrompt event") continue + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := getString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + 
log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) + } + case "assistantResponseEvent": var contentDelta string var toolUses []map[string]interface{} @@ -2174,6 +2432,15 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if c, ok := assistantResp["content"].(string); ok { contentDelta = c } + // Extract stop_reason from assistantResponseEvent + if sr := getString(assistantResp, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason) + } + if sr := getString(assistantResp, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) + } // Extract tool uses from response if tus, ok := assistantResp["toolUses"].([]interface{}); ok { for _, tuRaw := range tus { @@ -2199,11 +2466,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Handle text content with thinking mode support if contentDelta != "" { - // DIAGNOSTIC: Check for thinking tags in response - if strings.Contains(contentDelta, "") || strings.Contains(contentDelta, "") { - log.Infof("kiro: DIAGNOSTIC - Found thinking tag in response (len: %d)", len(contentDelta)) - } - // NOTE: Duplicate content filtering was removed because it incorrectly // filtered out legitimate repeated content (like consecutive newlines "\n\n"). // Streaming naturally can have identical chunks that are valid content. 
@@ -2211,6 +2473,52 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out outputLen += len(contentDelta) // Accumulate content for streaming token calculation accumulatedContent.WriteString(contentDelta) + + // Real-time usage estimation: Check if we should send a usage update + // This helps clients track context usage during long thinking sessions + shouldSendUsageUpdate := false + if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold { + shouldSendUsageUpdate = true + } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { + shouldSendUsageUpdate = true + } + + if shouldSendUsageUpdate { + // Calculate current output tokens using tiktoken + var currentOutputTokens int64 + if enc, encErr := getTokenizer(model); encErr == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + currentOutputTokens = int64(tokenCount) + } + } + // Fallback to character estimation if tiktoken fails + if currentOutputTokens == 0 { + currentOutputTokens = int64(accumulatedContent.Len() / 4) + if currentOutputTokens == 0 { + currentOutputTokens = 1 + } + } + + // Only send update if token count has changed significantly (at least 10 tokens) + if currentOutputTokens > lastReportedOutputTokens+10 { + // Send ping event with usage information + // This is a non-blocking update that clients can optionally process + pingEvent := e.buildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + lastReportedOutputTokens = currentOutputTokens + log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", + 
totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) + } + + lastUsageUpdateLen = accumulatedContent.Len() + lastUsageUpdateTime = time.Now() + } // Process content with thinking tag detection - based on amq2api implementation // This handles and tags that may span across chunks @@ -2577,10 +2885,10 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } // Streaming token calculation - calculate output tokens from accumulated content - // This provides more accurate token counting than simple character division + // Only use local estimation if server didn't provide usage (server-side usage takes priority) if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { // Try to use tiktoken for accurate counting - if enc, err := tokenizerForModel(model); err == nil { + if enc, err := getTokenizer(model); err == nil { if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { totalUsage.OutputTokens = int64(tokenCount) log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) @@ -2609,10 +2917,21 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens - // Determine stop reason based on whether tool uses were emitted - stopReason := "end_turn" - if hasToolUses { - stopReason = "tool_use" + // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn + stopReason := upstreamStopReason + if stopReason == "" { + if hasToolUses { + stopReason = "tool_use" + log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use") + } else { + stopReason = "end_turn" + log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)") } 
// Send message_delta event @@ -2758,6 +3077,24 @@ func (e *KiroExecutor) buildClaudeFinalEvent() []byte { return []byte("event: message_stop\ndata: " + string(result)) } +// buildClaudePingEventWithUsage creates a ping event with embedded usage information. +// This is used for real-time usage estimation during streaming. +// The usage field is a non-standard extension that clients can optionally process. +// Clients that don't recognize the usage field will simply ignore it. +func (e *KiroExecutor) buildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte { + event := map[string]interface{}{ + "type": "ping", + "usage": map[string]interface{}{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + "estimated": true, // Flag to indicate this is an estimate, not final + }, + } + result, _ := json.Marshal(event) + return []byte("event: ping\ndata: " + string(result)) +} + // buildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility. // This is used when streaming thinking content wrapped in tags. 
func (e *KiroExecutor) buildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte { @@ -2837,10 +3174,21 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c // Also check if expires_at is now in the future with sufficient buffer if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { - // If token expires more than 2 minutes from now, it's still valid - if time.Until(expTime) > 2*time.Minute { + // If token expires more than 5 minutes from now, it's still valid + if time.Until(expTime) > 5*time.Minute { log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) - return auth, nil + // CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks + // Without this, shouldRefresh() will return true again in 5 seconds + updated := auth.Clone() + // Set next refresh to 5 minutes before expiry, or at least 30 seconds from now + nextRefresh := expTime.Add(-5 * time.Minute) + minNextRefresh := time.Now().Add(30 * time.Second) + if nextRefresh.Before(minNextRefresh) { + nextRefresh = minNextRefresh + } + updated.NextRefreshAfter = nextRefresh + log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) + return updated, nil } } } @@ -2924,9 +3272,9 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c updated.Attributes["profile_arn"] = tokenData.ProfileArn } - // Set next refresh time to 30 minutes before expiry + // NextRefreshAfter is aligned with RefreshLead (5min) if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { - updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) } log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) @@ -2943,7 +3291,7 @@ func (e 
*KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c var translatorParam any // Pre-calculate input tokens from request if possible - if enc, err := tokenizerForModel(model); err == nil { + if enc, err := getTokenizer(model); err == nil { // Try OpenAI format first, then fall back to raw byte count estimation if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { totalUsage.InputTokens = inp diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 8998eb23..8ac91e03 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -137,15 +137,25 @@ func buildProxyTransport(proxyURL string) *http.Transport { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) return nil } - // Set up a custom transport using the SOCKS5 dialer + // Set up a custom transport using the SOCKS5 dialer with optimized connection pooling transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users + MaxConnsPerHost: 0, // No limit on max concurrent connections per host + IdleConnTimeout: 90 * time.Second, } } else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy - transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} + // Configure HTTP or HTTPS proxy with optimized connection pooling + transport = &http.Transport{ + Proxy: http.ProxyURL(parsedURL), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users + MaxConnsPerHost: 0, // No limit on max concurrent connections per host + IdleConnTimeout: 90 * time.Second, + } } else { log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) return nil diff --git a/internal/runtime/executor/token_helpers.go 
b/internal/runtime/executor/token_helpers.go index f4236f9b..3dd2a2b5 100644 --- a/internal/runtime/executor/token_helpers.go +++ b/internal/runtime/executor/token_helpers.go @@ -2,43 +2,107 @@ package executor import ( "fmt" + "regexp" + "strconv" "strings" + "sync" "github.com/tidwall/gjson" "github.com/tiktoken-go/tokenizer" ) +// tokenizerCache stores tokenizer instances to avoid repeated creation +var tokenizerCache sync.Map + +// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models +// where tiktoken may not accurately estimate token counts (e.g., Claude models) +type TokenizerWrapper struct { + Codec tokenizer.Codec + AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates +} + +// Count returns the token count with adjustment factor applied +func (tw *TokenizerWrapper) Count(text string) (int, error) { + count, err := tw.Codec.Count(text) + if err != nil { + return 0, err + } + if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 { + return int(float64(count) * tw.AdjustmentFactor), nil + } + return count, nil +} + +// getTokenizer returns a cached tokenizer for the given model. +// This improves performance by avoiding repeated tokenizer creation. +func getTokenizer(model string) (*TokenizerWrapper, error) { + // Check cache first + if cached, ok := tokenizerCache.Load(model); ok { + return cached.(*TokenizerWrapper), nil + } + + // Cache miss, create new tokenizer + wrapper, err := tokenizerForModel(model) + if err != nil { + return nil, err + } + + // Store in cache (use LoadOrStore to handle race conditions) + actual, _ := tokenizerCache.LoadOrStore(model, wrapper) + return actual.(*TokenizerWrapper), nil +} + // tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id. -func tokenizerForModel(model string) (tokenizer.Codec, error) { +// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate. 
+func tokenizerForModel(model string) (*TokenizerWrapper, error) { sanitized := strings.ToLower(strings.TrimSpace(model)) + + // Claude models use cl100k_base with 1.1 adjustment factor + // because tiktoken may underestimate Claude's actual token count + if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") { + enc, err := tokenizer.Get(tokenizer.Cl100kBase) + if err != nil { + return nil, err + } + return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil + } + + var enc tokenizer.Codec + var err error + switch { case sanitized == "": - return tokenizer.Get(tokenizer.Cl100kBase) + enc, err = tokenizer.Get(tokenizer.Cl100kBase) case strings.HasPrefix(sanitized, "gpt-5"): - return tokenizer.ForModel(tokenizer.GPT5) + enc, err = tokenizer.ForModel(tokenizer.GPT5) case strings.HasPrefix(sanitized, "gpt-5.1"): - return tokenizer.ForModel(tokenizer.GPT5) + enc, err = tokenizer.ForModel(tokenizer.GPT5) case strings.HasPrefix(sanitized, "gpt-4.1"): - return tokenizer.ForModel(tokenizer.GPT41) + enc, err = tokenizer.ForModel(tokenizer.GPT41) case strings.HasPrefix(sanitized, "gpt-4o"): - return tokenizer.ForModel(tokenizer.GPT4o) + enc, err = tokenizer.ForModel(tokenizer.GPT4o) case strings.HasPrefix(sanitized, "gpt-4"): - return tokenizer.ForModel(tokenizer.GPT4) + enc, err = tokenizer.ForModel(tokenizer.GPT4) case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"): - return tokenizer.ForModel(tokenizer.GPT35Turbo) + enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo) case strings.HasPrefix(sanitized, "o1"): - return tokenizer.ForModel(tokenizer.O1) + enc, err = tokenizer.ForModel(tokenizer.O1) case strings.HasPrefix(sanitized, "o3"): - return tokenizer.ForModel(tokenizer.O3) + enc, err = tokenizer.ForModel(tokenizer.O3) case strings.HasPrefix(sanitized, "o4"): - return tokenizer.ForModel(tokenizer.O4Mini) + enc, err = tokenizer.ForModel(tokenizer.O4Mini) default: - 
return tokenizer.Get(tokenizer.O200kBase) + enc, err = tokenizer.Get(tokenizer.O200kBase) } + + if err != nil { + return nil, err + } + return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil } // countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads. -func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) { +func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { if enc == nil { return 0, fmt.Errorf("encoder is nil") } @@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) { return 0, nil } + // Count text tokens count, err := enc.Count(joined) if err != nil { return 0, err } - return int64(count), nil + + // Extract and add image tokens from placeholders + imageTokens := extractImageTokens(joined) + + return int64(count) + int64(imageTokens), nil +} + +// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads. +// This handles Claude's message format with system, messages, and tools. +// Image tokens are estimated based on image dimensions when available. 
+func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { + if enc == nil { + return 0, fmt.Errorf("encoder is nil") + } + if len(payload) == 0 { + return 0, nil + } + + root := gjson.ParseBytes(payload) + segments := make([]string, 0, 32) + + // Collect system prompt (can be string or array of content blocks) + collectClaudeSystem(root.Get("system"), &segments) + + // Collect messages + collectClaudeMessages(root.Get("messages"), &segments) + + // Collect tools + collectClaudeTools(root.Get("tools"), &segments) + + joined := strings.TrimSpace(strings.Join(segments, "\n")) + if joined == "" { + return 0, nil + } + + // Count text tokens + count, err := enc.Count(joined) + if err != nil { + return 0, err + } + + // Extract and add image tokens from placeholders + imageTokens := extractImageTokens(joined) + + return int64(count) + int64(imageTokens), nil +} + +// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens +var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`) + +// extractImageTokens extracts image token estimates from placeholder text. +// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count. +func extractImageTokens(text string) int { + matches := imageTokenPattern.FindAllStringSubmatch(text, -1) + total := 0 + for _, match := range matches { + if len(match) > 1 { + if tokens, err := strconv.Atoi(match[1]); err == nil { + total += tokens + } + } + } + return total +} + +// estimateImageTokens calculates estimated tokens for an image based on dimensions. +// Based on Claude's image token calculation: tokens ≈ (width * height) / 750 +// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images). 
+func estimateImageTokens(width, height float64) int { + if width <= 0 || height <= 0 { + // No valid dimensions, use default estimate (medium-sized image) + return 1000 + } + + tokens := int(width * height / 750) + + // Apply bounds + if tokens < 85 { + tokens = 85 + } + if tokens > 1590 { + tokens = 1590 + } + + return tokens +} + +// collectClaudeSystem extracts text from Claude's system field. +// System can be a string or an array of content blocks. +func collectClaudeSystem(system gjson.Result, segments *[]string) { + if !system.Exists() { + return + } + if system.Type == gjson.String { + addIfNotEmpty(segments, system.String()) + return + } + if system.IsArray() { + system.ForEach(func(_, block gjson.Result) bool { + blockType := block.Get("type").String() + if blockType == "text" || blockType == "" { + addIfNotEmpty(segments, block.Get("text").String()) + } + // Also handle plain string blocks + if block.Type == gjson.String { + addIfNotEmpty(segments, block.String()) + } + return true + }) + } +} + +// collectClaudeMessages extracts text from Claude's messages array. +func collectClaudeMessages(messages gjson.Result, segments *[]string) { + if !messages.Exists() || !messages.IsArray() { + return + } + messages.ForEach(func(_, message gjson.Result) bool { + addIfNotEmpty(segments, message.Get("role").String()) + collectClaudeContent(message.Get("content"), segments) + return true + }) +} + +// collectClaudeContent extracts text from Claude's content field. +// Content can be a string or an array of content blocks. +// For images, estimates token count based on dimensions when available. 
+func collectClaudeContent(content gjson.Result, segments *[]string) { + if !content.Exists() { + return + } + if content.Type == gjson.String { + addIfNotEmpty(segments, content.String()) + return + } + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + partType := part.Get("type").String() + switch partType { + case "text": + addIfNotEmpty(segments, part.Get("text").String()) + case "image": + // Estimate image tokens based on dimensions if available + source := part.Get("source") + if source.Exists() { + width := source.Get("width").Float() + height := source.Get("height").Float() + if width > 0 && height > 0 { + tokens := estimateImageTokens(width, height) + addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens)) + } else { + // No dimensions available, use default estimate + addIfNotEmpty(segments, "[IMAGE:1000 tokens]") + } + } else { + // No source info, use default estimate + addIfNotEmpty(segments, "[IMAGE:1000 tokens]") + } + case "tool_use": + addIfNotEmpty(segments, part.Get("id").String()) + addIfNotEmpty(segments, part.Get("name").String()) + if input := part.Get("input"); input.Exists() { + addIfNotEmpty(segments, input.Raw) + } + case "tool_result": + addIfNotEmpty(segments, part.Get("tool_use_id").String()) + collectClaudeContent(part.Get("content"), segments) + case "thinking": + addIfNotEmpty(segments, part.Get("thinking").String()) + default: + // For unknown types, try to extract any text content + if part.Type == gjson.String { + addIfNotEmpty(segments, part.String()) + } else if part.Type == gjson.JSON { + addIfNotEmpty(segments, part.Raw) + } + } + return true + }) + } +} + +// collectClaudeTools extracts text from Claude's tools array. 
+func collectClaudeTools(tools gjson.Result, segments *[]string) { + if !tools.Exists() || !tools.IsArray() { + return + } + tools.ForEach(func(_, tool gjson.Result) bool { + addIfNotEmpty(segments, tool.Get("name").String()) + addIfNotEmpty(segments, tool.Get("description").String()) + if inputSchema := tool.Get("input_schema"); inputSchema.Exists() { + addIfNotEmpty(segments, inputSchema.Raw) + } + return true + }) } // buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators. diff --git a/internal/util/proxy.go b/internal/util/proxy.go index aea52ba8..e5ac7cd6 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "net/url" + "time" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" log "github.com/sirupsen/logrus" @@ -36,15 +37,25 @@ func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) return httpClient } - // Set up a custom transport using the SOCKS5 dialer. + // Set up a custom transport using the SOCKS5 dialer with optimized connection pooling transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users + MaxConnsPerHost: 0, // No limit on max concurrent connections per host + IdleConnTimeout: 90 * time.Second, } } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy. 
- transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + // Configure HTTP or HTTPS proxy with optimized connection pooling + transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users + MaxConnsPerHost: 0, // No limit on max concurrent connections per host + IdleConnTimeout: 90 * time.Second, + } } } // If a new transport was created, apply it to the HTTP client. diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index b95d103b..1eed4b94 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string { } // RefreshLead indicates how soon before expiry a refresh should be attempted. +// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh. func (a *KiroAuthenticator) RefreshLead() *time.Duration { - d := 30 * time.Minute + d := 5 * time.Minute return &d } @@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts "source": "aws-builder-id", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } if tokenData.Email != "" { @@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con "source": "google-oauth", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } if tokenData.Email != "" { @@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con "source": "github-oauth", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: 
expiresAt.Add(-5 * time.Minute), } if tokenData.Email != "" { @@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C "source": "kiro-ide-import", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } // Display the email if extracted @@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut updated.Metadata["refresh_token"] = tokenData.RefreshToken updated.Metadata["expires_at"] = tokenData.ExpiresAt updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization - updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + // NextRefreshAfter is aligned with RefreshLead (5min) + updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) return updated, nil } diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index dc7887e7..eba33bb8 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -40,7 +40,7 @@ type RefreshEvaluator interface { const ( refreshCheckInterval = 5 * time.Second refreshPendingBackoff = time.Minute - refreshFailureBackoff = 5 * time.Minute + refreshFailureBackoff = 1 * time.Minute quotaBackoffBase = time.Second quotaBackoffMax = 30 * time.Minute ) @@ -1471,7 +1471,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { updated.Runtime = auth.Runtime } updated.LastRefreshedAt = now - updated.NextRefreshAfter = time.Time{} + // Preserve NextRefreshAfter set by the Authenticator + // If the Authenticator set a reasonable refresh time, it should not be overwritten + // If the Authenticator did not set it (zero value), shouldRefresh will use default logic updated.LastError = nil updated.UpdatedAt = now _, _ = m.Update(ctx, updated) From 75793a18f06c9a539f5995aef86f4b783448f6dc Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sat, 13 Dec 
2025 11:36:22 +0800 Subject: [PATCH 32/38] feat(kiro): Add Kiro OAuth login entry and auth file filter in Web UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为Kiro供应商添加WEB UI OAuth登录入口和认证文件过滤器 ## Changes / 更改内容 ### Frontend / 前端 (management.html) - Add Kiro OAuth card UI with support for AWS Builder ID, Google, and GitHub login methods - 添加Kiro OAuth卡片UI,支持AWS Builder ID、Google和GitHub三种登录方式 - Add i18n translations for Kiro OAuth (Chinese and English) - 添加Kiro OAuth的中英文国际化翻译 - Add Kiro filter button in auth files management page - 在认证文件管理页面添加Kiro过滤按钮 - Implement JavaScript methods: startKiroOAuth(), openKiroLink(), copyKiroLink(), copyKiroDeviceCode(), startKiroOAuthPolling(), resetKiroOAuthUI() - 实现JavaScript方法:startKiroOAuth()、openKiroLink()、copyKiroLink()、copyKiroDeviceCode()、startKiroOAuthPolling()、resetKiroOAuthUI() ### Backend / 后端 - Add /kiro-auth-url endpoint for Kiro OAuth authentication (auth_files.go) - 添加/kiro-auth-url端点用于Kiro OAuth认证 (auth_files.go) - Fix GetAuthStatus() to correctly parse device_code and auth_url status - 修复GetAuthStatus()以正确解析device_code和auth_url状态 - Change status delimiter from ':' to '|' to avoid URL parsing issues - 将状态分隔符从':'改为'|'以避免URL解析问题 - Export CreateToken method in social_auth.go - 在social_auth.go中导出CreateToken方法 - Register Kiro OAuth routes in server.go - 在server.go中注册Kiro OAuth路由 ## Files Modified / 修改的文件 - management.html - internal/api/handlers/management/auth_files.go - internal/api/server.go - internal/auth/kiro/social_auth.go --- .../api/handlers/management/auth_files.go | 328 +++++++++++++++++- internal/api/modules/amp/proxy.go | 18 - internal/api/server.go | 13 + internal/auth/kiro/social_auth.go | 6 +- internal/runtime/executor/proxy_helpers.go | 17 +- internal/util/proxy.go | 17 +- 6 files changed, 347 insertions(+), 52 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index d35570ce..265b4f8c 
100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -3,6 +3,9 @@ package management import ( "bytes" "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -23,6 +26,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -2154,9 +2158,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") - if err, ok := getOAuthStatus(state); ok { - if err != "" { - c.JSON(200, gin.H{"status": "error", "error": err}) + if statusValue, ok := getOAuthStatus(state); ok { + if statusValue != "" { + // Check for device_code prefix (Kiro AWS Builder ID flow) + // Format: "device_code|verification_url|user_code" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "device_code|") { + parts := strings.SplitN(statusValue, "|", 3) + if len(parts) == 3 { + c.JSON(200, gin.H{ + "status": "device_code", + "verification_url": parts[1], + "user_code": parts[2], + }) + return + } + } + // Check for auth_url prefix (Kiro social auth flow) + // Format: "auth_url|url" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "auth_url|") { + authURL := strings.TrimPrefix(statusValue, "auth_url|") + c.JSON(200, gin.H{ + "status": "auth_url", + "url": authURL, + }) + return + } + // Otherwise treat as error + c.JSON(200, gin.H{"status": "error", "error": statusValue}) } else { c.JSON(200, gin.H{"status": 
"wait"}) return @@ -2166,3 +2196,295 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } deleteOAuthStatus(state) } + +const kiroCallbackPort = 9876 + +func (h *Handler) RequestKiroToken(c *gin.Context) { + ctx := context.Background() + + // Get the login method from query parameter (default: aws for device code flow) + method := strings.ToLower(strings.TrimSpace(c.Query("method"))) + if method == "" { + method = "aws" + } + + fmt.Println("Initializing Kiro authentication...") + + state := fmt.Sprintf("kiro-%d", time.Now().UnixNano()) + + switch method { + case "aws", "builder-id": + // AWS Builder ID uses device code flow (no callback needed) + go func() { + ssoClient := kiroauth.NewSSOOIDCClient(h.cfg) + + // Step 1: Register client + fmt.Println("Registering client...") + regResp, err := ssoClient.RegisterClient(ctx) + if err != nil { + log.Errorf("Failed to register client: %v", err) + setOAuthStatus(state, "Failed to register client") + return + } + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + log.Errorf("Failed to start device auth: %v", err) + setOAuthStatus(state, "Failed to start device authorization") + return + } + + // Store the verification URL for the frontend to display + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode) + + // Step 3: Poll for token + fmt.Println("Waiting for authorization...") + interval := 5 * time.Second + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + setOAuthStatus(state, "Authorization cancelled") + return + case <-time.After(interval): + tokenResp, err := 
ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + log.Errorf("Token creation failed: %v", err) + setOAuthStatus(state, "Token creation failed") + return + } + + // Success! Save the token + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "builder-id", + "provider": "AWS", + "client_id": regResp.ClientID, + "client_secret": regResp.ClientSecret, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + } + + setOAuthStatus(state, "Authorization timed out") + }() + + // Return immediately with the state for polling + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"}) + + case "google", "github": + // Social auth uses protocol handler - for WEB UI we use a callback forwarder + provider := "Google" + if method == "github" { + provider = "Github" + } + + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/kiro/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute kiro callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start kiro callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + + go func() { + if isWebUI { + defer stopCallbackForwarder(kiroCallbackPort) + } + + socialClient := kiroauth.NewSocialAuthClient(h.cfg) + + // Generate PKCE codes + codeVerifier, codeChallenge, err := generateKiroPKCE() + if err != nil { + log.Errorf("Failed to generate PKCE: %v", err) + setOAuthStatus(state, "Failed to generate PKCE") + return + } + + // Build login URL + authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + "https://prod.us-east-1.auth.desktop.kiro.dev", + provider, + url.QueryEscape(kiroauth.KiroRedirectURI), + codeChallenge, + state, + ) + + // Store auth URL for frontend + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "auth_url|"+authURL) + + // Wait for callback file + waitFile := filepath.Join(h.cfg.AuthDir, 
fmt.Sprintf(".oauth-kiro-%s.oauth", state)) + deadline := time.Now().Add(5 * time.Minute) + + for { + if time.Now().After(deadline) { + log.Error("oauth flow timed out") + setOAuthStatus(state, "OAuth flow timed out") + return + } + if data, errR := os.ReadFile(waitFile); errR == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(waitFile) + if errStr := m["error"]; errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + setOAuthStatus(state, "Authentication failed") + return + } + if m["state"] != state { + log.Errorf("State mismatch") + setOAuthStatus(state, "State mismatch") + return + } + code := m["code"] + if code == "" { + log.Error("No authorization code received") + setOAuthStatus(state, "No authorization code received") + return + } + + // Exchange code for tokens + tokenReq := &kiroauth.CreateTokenRequest{ + Code: code, + CodeVerifier: codeVerifier, + RedirectURI: kiroauth.KiroRedirectURI, + } + + tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq) + if errToken != nil { + log.Errorf("Failed to exchange code for tokens: %v", errToken) + setOAuthStatus(state, "Failed to exchange code for tokens") + return + } + + // Save the token + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "profile_arn": tokenResp.ProfileArn, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "social", + 
"provider": provider, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + time.Sleep(500 * time.Millisecond) + } + }() + + setOAuthStatus(state, "") + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"}) + + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"}) + } +} + +// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth. +func generateKiroPKCE() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 5a3f2081..6ea092c4 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -7,13 +7,11 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/http/httputil" "net/url" "strconv" "strings" - "time" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" @@ -38,22 +36,6 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi } proxy := httputil.NewSingleHostReverseProxy(parsed) - - // Configure custom Transport with optimized connection pooling for high concurrency - proxy.Transport = &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 
2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 60 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - originalDirector := proxy.Director // Modify outgoing requests to inject API key and fix routing diff --git a/internal/api/server.go b/internal/api/server.go index ade08fef..d702551e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -421,6 +421,18 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, oauthCallbackSuccessHTML) }) + s.engine.GET("/kiro/callback", func(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + errStr := c.Query("error") + if state != "" { + file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state) + _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusOK, oauthCallbackSuccessHTML) + }) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. 
} @@ -586,6 +598,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) + mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go index 61c67886..2ac29bf8 100644 --- a/internal/auth/kiro/social_auth.go +++ b/internal/auth/kiro/social_auth.go @@ -126,8 +126,8 @@ func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, s ) } -// createToken exchanges the authorization code for tokens. -func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { +// CreateToken exchanges the authorization code for tokens. +func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { body, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal token request: %w", err) @@ -326,7 +326,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP RedirectURI: KiroRedirectURI, } - tokenResp, err := c.createToken(ctx, tokenReq) + tokenResp, err := c.CreateToken(ctx, tokenReq) if err != nil { return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) } diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 8ac91e03..4cda7b16 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -7,7 +7,6 @@ import ( "net/url" "strings" "sync" - "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -137,25 +136,15 @@ func buildProxyTransport(proxyURL string) *http.Transport { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) 
return nil } - // Set up a custom transport using the SOCKS5 dialer with optimized connection pooling + // Set up a custom transport using the SOCKS5 dialer transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, } } else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy with optimized connection pooling - transport = &http.Transport{ - Proxy: http.ProxyURL(parsedURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, - } + // Configure HTTP or HTTPS proxy + transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} } else { log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) return nil diff --git a/internal/util/proxy.go b/internal/util/proxy.go index e5ac7cd6..aea52ba8 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/url" - "time" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" log "github.com/sirupsen/logrus" @@ -37,25 +36,15 @@ func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) return httpClient } - // Set up a custom transport using the SOCKS5 dialer with optimized connection pooling + // Set up a custom transport using the SOCKS5 dialer. 
transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, } } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy with optimized connection pooling - transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, - } + // Configure HTTP or HTTPS proxy. + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} } } // If a new transport was created, apply it to the HTTP client. From 1ea0cff3a45067f7c4839c728c3e59d92f521112 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sat, 13 Dec 2025 12:57:47 +0800 Subject: [PATCH 33/38] fix: add missing import declarations for net and time packages --- internal/api/modules/amp/proxy.go | 1 + internal/runtime/executor/proxy_helpers.go | 1 + 2 files changed, 2 insertions(+) diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 6ea092c4..91716e36 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 4cda7b16..8998eb23 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -7,6 +7,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth 
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" From 92ca5078c1aec13b4be77c83bb06a7b3917d7d9d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 13 Dec 2025 13:40:39 +0800 Subject: [PATCH 34/38] docs(readme): update contributors for Kiro integration (AWS CodeWhisperer) --- README.md | 2 +- README_CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7a552135..d00e91c9 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ The Plus release stays in lockstep with the mainline features. ## Differences from the Mainline - Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) -- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration) +- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/) ## Contributing diff --git a/README_CN.md b/README_CN.md index 163bec07..21132b86 100644 --- a/README_CN.md +++ b/README_CN.md @@ -11,7 +11,7 @@ ## 与主线版本版本差异 - 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 -- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供 +- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供 ## 贡献 From 01cf2211671307f297ef9774b0c6dfd31f2f98b4 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sun, 14 Dec 2025 06:58:50 +0800 Subject: [PATCH 35/38] =?UTF-8?q?feat(kiro):=20=E4=BB=A3=E7=A0=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E9=87=8D=E6=9E=84=20+=20OpenAI=E7=BF=BB=E8=AF=91?= 
=?UTF-8?q?=E5=99=A8=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/api/modules/amp/response_rewriter.go | 60 +- internal/runtime/executor/kiro_executor.go | 3206 ++++------------- internal/translator/init.go | 2 +- internal/translator/kiro/claude/init.go | 5 +- .../translator/kiro/claude/kiro_claude.go | 18 +- .../kiro/claude/kiro_claude_request.go | 603 ++++ .../kiro/claude/kiro_claude_response.go | 184 + .../kiro/claude/kiro_claude_stream.go | 176 + .../kiro/claude/kiro_claude_tools.go | 522 +++ internal/translator/kiro/common/constants.go | 66 + .../translator/kiro/common/message_merge.go | 125 + internal/translator/kiro/common/utils.go | 16 + .../chat-completions/kiro_openai_request.go | 348 -- .../chat-completions/kiro_openai_response.go | 404 --- .../openai/{chat-completions => }/init.go | 13 +- .../translator/kiro/openai/kiro_openai.go | 368 ++ .../kiro/openai/kiro_openai_request.go | 604 ++++ .../kiro/openai/kiro_openai_response.go | 264 ++ .../kiro/openai/kiro_openai_stream.go | 207 ++ 19 files changed, 3898 insertions(+), 3293 deletions(-) create mode 100644 internal/translator/kiro/claude/kiro_claude_request.go create mode 100644 internal/translator/kiro/claude/kiro_claude_response.go create mode 100644 internal/translator/kiro/claude/kiro_claude_stream.go create mode 100644 internal/translator/kiro/claude/kiro_claude_tools.go create mode 100644 internal/translator/kiro/common/constants.go create mode 100644 internal/translator/kiro/common/message_merge.go create mode 100644 internal/translator/kiro/common/utils.go delete mode 100644 internal/translator/kiro/openai/chat-completions/kiro_openai_request.go delete mode 100644 internal/translator/kiro/openai/chat-completions/kiro_openai_response.go rename internal/translator/kiro/openai/{chat-completions => }/init.go (56%) create mode 100644 internal/translator/kiro/openai/kiro_openai.go create mode 100644 
internal/translator/kiro/openai/kiro_openai_request.go create mode 100644 internal/translator/kiro/openai/kiro_openai_response.go create mode 100644 internal/translator/kiro/openai/kiro_openai_stream.go diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index e906f143..d78af9f1 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -29,15 +29,71 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe } } +const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap + +func looksLikeSSEChunk(data []byte) bool { + // Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered. + // Heuristics are intentionally simple and cheap. + return bytes.Contains(data, []byte("data:")) || + bytes.Contains(data, []byte("event:")) || + bytes.Contains(data, []byte("message_start")) || + bytes.Contains(data, []byte("message_delta")) || + bytes.Contains(data, []byte("content_block_start")) || + bytes.Contains(data, []byte("content_block_delta")) || + bytes.Contains(data, []byte("content_block_stop")) || + bytes.Contains(data, []byte("\n\n")) +} + +func (rw *ResponseRewriter) enableStreaming(reason string) error { + if rw.isStreaming { + return nil + } + rw.isStreaming = true + + // Flush any previously buffered data to avoid reordering or data loss. + if rw.body != nil && rw.body.Len() > 0 { + buf := rw.body.Bytes() + // Copy before Reset() to keep bytes stable. 
+ toFlush := make([]byte, len(buf)) + copy(toFlush, buf) + rw.body.Reset() + + if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil { + return err + } + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + } + + log.Debugf("amp response rewriter: switched to streaming (%s)", reason) + return nil +} + // Write intercepts response writes and buffers them for model name replacement func (rw *ResponseRewriter) Write(data []byte) (int, error) { - // Detect streaming on first write - if rw.body.Len() == 0 && !rw.isStreaming { + // Detect streaming on first write (header-based) + if !rw.isStreaming && rw.body.Len() == 0 { contentType := rw.Header().Get("Content-Type") rw.isStreaming = strings.Contains(contentType, "text/event-stream") || strings.Contains(contentType, "stream") } + if !rw.isStreaming { + // Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong. + if looksLikeSSEChunk(data) { + if err := rw.enableStreaming("sse heuristic"); err != nil { + return 0, err + } + } else if rw.body.Len()+len(data) > maxBufferedResponseBytes { + // Safety cap: avoid unbounded buffering on large responses. 
+ log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes) + if err := rw.enableStreaming("buffer limit"); err != nil { + return 0, err + } + } + } + if rw.isStreaming { return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index cbc5443b..1d4d85a5 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -10,38 +10,34 @@ import ( "fmt" "io" "net/http" - "regexp" "strings" "sync" "time" - "unicode/utf8" "github.com/google/uuid" kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/gin-gonic/gin" ) const ( // Kiro API common constants - kiroContentType = "application/x-amz-json-1.0" - kiroAcceptStream = "*/*" - kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream - kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." 
+ kiroContentType = "application/x-amz-json-1.0" + kiroAcceptStream = "*/*" // Event Stream frame size constants for boundary protection // AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes) // Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4) - minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc) - maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB + minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc) + maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB // Event Stream error type constants ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable @@ -50,73 +46,13 @@ const ( kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" - - // Thinking mode support - based on amq2api implementation - // These tags wrap reasoning content in the response stream - thinkingStartTag = "" - thinkingEndTag = "" - // thinkingHint is injected into the request to enable interleaved thinking mode - // This tells the model to use thinking tags and sets the max thinking length - thinkingHint = "interleaved16000" - - // kiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. - // AWS Kiro API has a 2-3 minute timeout for large file write operations. - kiroAgenticSystemPrompt = ` -# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY) - -You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure. 
- -## ABSOLUTE LIMITS -- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS -- **RECOMMENDED 300 LINES** or less for optimal performance -- **NEVER** write entire files in one operation if >300 lines - -## MANDATORY CHUNKED WRITE STRATEGY - -### For NEW FILES (>300 lines total): -1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite -2. THEN: Append remaining content in 250-300 line chunks using file append operations -3. REPEAT: Continue appending until complete - -### For EDITING EXISTING FILES: -1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed -2. NEVER rewrite entire files - use incremental modifications -3. Split large refactors into multiple small, focused edits - -### For LARGE CODE GENERATION: -1. Generate in logical sections (imports, types, functions separately) -2. Write each section as a separate operation -3. Use append operations for subsequent sections - -## EXAMPLES OF CORRECT BEHAVIOR - -✅ CORRECT: Writing a 600-line file -- Operation 1: Write lines 1-300 (initial file creation) -- Operation 2: Append lines 301-600 - -✅ CORRECT: Editing multiple functions -- Operation 1: Edit function A -- Operation 2: Edit function B -- Operation 3: Edit function C - -❌ WRONG: Writing 500 lines in single operation → TIMEOUT -❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT -❌ WRONG: Generating massive code blocks without chunking → TIMEOUT - -## WHY THIS MATTERS -- Server has 2-3 minute timeout for operations -- Large writes exceed timeout and FAIL completely -- Chunked writes are FASTER and more RELIABLE -- Failed writes waste time and require retry - -REMEMBER: When in doubt, write LESS per operation. 
Multiple small operations > one large operation.` ) // Real-time usage estimation configuration // These control how often usage updates are sent during streaming var ( - usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters - usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first + usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters + usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first ) // kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. @@ -186,7 +122,7 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { } preference = strings.ToLower(strings.TrimSpace(preference)) - + // Create new slice to avoid modifying global state var sorted []kiroEndpointConfig var remaining []kiroEndpointConfig @@ -221,8 +157,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { // KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. type KiroExecutor struct { - cfg *config.Config - refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions + cfg *config.Config + refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions } // NewKiroExecutor creates a new Kiro executor instance. @@ -236,7 +172,6 @@ func (e *KiroExecutor) Identifier() string { return "kiro" } // PrepareRequest prepares the HTTP request before execution. func (e *KiroExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } - // Execute sends the request to Kiro API and returns the response. // Supports automatic token refresh on 401/403 errors. 
func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { @@ -244,14 +179,6 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req if accessToken == "" { return resp, fmt.Errorf("kiro: access token not found in auth") } - if profileArn == "" { - // Only warn if not using builder-id auth (which doesn't need profileArn) - if auth == nil || auth.Metadata == nil { - log.Debugf("kiro: profile ARN not found in auth (may be normal for builder-id)") - } else if authMethod, ok := auth.Metadata["auth_method"].(string); !ok || authMethod != "builder-id" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") - } - } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -274,31 +201,14 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) kiroModelID := e.mapModelToKiro(req.Model) - - // Check if this is an agentic model variant - isAgentic := strings.HasSuffix(req.Model, "-agentic") - - // Check if this is a chat-only model variant (no tool calling) - isChatOnly := strings.HasSuffix(req.Model, "-chat") - - // Determine initial origin - always use AI_EDITOR to match AIClient-2-API behavior - // AIClient-2-API uses AI_EDITOR for all models, which is the Kiro IDE quota - // Note: CLI origin is for Amazon Q quota, but AIClient-2-API doesn't use it - currentOrigin := "AI_EDITOR" - - // Determine if profileArn should be included based on auth method - // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) - effectiveProfileArn := profileArn - if auth != nil && auth.Metadata != nil { - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { - effectiveProfileArn = "" 
// Don't include profileArn for builder-id auth - } - } - - kiroPayload := e.buildKiroPayload(body, kiroModelID, effectiveProfileArn, currentOrigin, isAgentic, isChatOnly) + + // Determine agentic mode and effective profile ARN using helper functions + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) // Execute with retry on 401/403 and 429 (quota exhausted) - resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + // Note: currentOrigin and kiroPayload are built inside executeWithRetry for each endpoint + resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly) return resp, err } @@ -311,247 +221,252 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
var resp cliproxyexecutor.Response maxRetries := 2 // Allow retries for token refresh + endpoint fallback endpointConfigs := getKiroEndpointConfigs(auth) + var last429Err error for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { endpointConfig := endpointConfigs[endpointIdx] url := endpointConfig.URL // Use this endpoint's compatible Origin (critical for avoiding 403 errors) currentOrigin = endpointConfig.Origin - + // Rebuild payload with the correct origin for this endpoint // Each endpoint requires its matching Origin value in the request body - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - + kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)", endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) - for attempt := 0; attempt <= maxRetries; attempt++ { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) - if err != nil { - return resp, err - } - - httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - httpReq.Header.Set("Accept", kiroAcceptStream) - // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) - httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) - httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") - httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - - var attrs map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - - var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - 
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: kiroPayload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second) - httpResp, err := httpClient.Do(httpReq) - if err != nil { - recordAPIResponseError(ctx, e.cfg, err) - return resp, err - } - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - - // Handle 429 errors (quota exhausted) - try next endpoint - // Each endpoint has its own quota pool, so we can try different endpoints - if httpResp.StatusCode == 429 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - log.Warnf("kiro: %s endpoint quota exhausted (429), will try next endpoint", endpointConfig.Name) - - // Break inner retry loop to try next endpoint (which has different quota) - break - } - - // Handle 5xx server errors with exponential backoff retry - if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - if attempt < maxRetries { - // Exponential backoff: 1s, 2s, 4s... 
(max 30s) - backoff := time.Duration(1< 30*time.Second { - backoff = 30 * time.Second - } - log.Warnf("kiro: server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) - time.Sleep(backoff) - continue + for attempt := 0; attempt <= maxRetries; attempt++ { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return resp, err } - log.Errorf("kiro: server error %d after %d retries", httpResp.StatusCode, maxRetries) - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - // Handle 401 errors with token refresh and retry - // 401 = Unauthorized (token expired/invalid) - refresh token - if httpResp.StatusCode == 401 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("Accept", kiroAcceptStream) + // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - if attempt < maxRetries { - log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Errorf("kiro: token refresh failed: %v", refreshErr) - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + 
authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) - if refreshedAuth != nil { - auth = refreshedAuth - accessToken, profileArn = kiroCredentials(auth) - // Rebuild payload with new profile ARN if changed - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - log.Infof("kiro: token refreshed successfully, retrying request") + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted + last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} + + log.Warnf("kiro: %s endpoint quota exhausted (429), will try next endpoint, body: %s", + endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // Break inner retry loop to try next endpoint (which has different quota) + break + } + + // Handle 5xx server errors with exponential backoff retry + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + // Exponential backoff: 1s, 2s, 4s... 
(max 30s) + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) continue } + log.Errorf("kiro: server error %d after %d retries", httpResp.StatusCode, maxRetries) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) - // Handle 402 errors - Monthly Limit Reached - if httpResp.StatusCode == 402 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) + if attempt < maxRetries { + log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) - log.Warnf("kiro: received 402 (monthly limit). 
Upstream body: %s", string(respBody)) + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } - // Return upstream error body directly - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - // Handle 403 errors - Access Denied / Token Expired - // Do NOT switch endpoints for 403 errors - if httpResp.StatusCode == 403 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - // Log the 403 error details for debugging - log.Warnf("kiro: received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - - respBodyStr := string(respBody) - - // Check for SUSPENDED status - return immediately without retry - if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { - log.Errorf("kiro: account is suspended, cannot proceed") - return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} - } - - // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) - isTokenRelated := strings.Contains(respBodyStr, "token") || - strings.Contains(respBodyStr, "expired") || - strings.Contains(respBodyStr, "invalid") || - strings.Contains(respBodyStr, "unauthorized") - - if isTokenRelated && attempt < maxRetries { - log.Warnf("kiro: 403 appears token-related, attempting token refresh") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Errorf("kiro: token refresh failed: %v", refreshErr) - // Token refresh failed - return error immediately - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - if refreshedAuth != nil { - auth = refreshedAuth - accessToken, profileArn = kiroCredentials(auth) - 
kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - log.Infof("kiro: token refreshed for 403, retrying request") - continue + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed successfully, retrying request") + continue + } } + + log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - // For non-token 403 or after max retries, return error immediately + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: received 402 (monthly limit). 
Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired // Do NOT switch endpoints for 403 errors - log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") - return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - err = statusErr{code: httpResp.StatusCode, msg: string(b)} - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - return resp, err - } + // Log the 403 error details for debugging + log.Warnf("kiro: received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - }() + respBodyStr := string(respBody) - content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body) - if err != nil { - recordAPIResponseError(ctx, e.cfg, err) - return resp, err - } - - // Fallback for usage if missing from upstream - if usageInfo.TotalTokens == 0 { - if enc, encErr := getTokenizer(req.Model); encErr == nil { - if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { - usageInfo.InputTokens = inp + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || 
strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + log.Errorf("kiro: account is suspended, cannot proceed") + return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} } + + // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed for 403, retrying request") + continue + } + } + + // For non-token 403 or after max retries, return error immediately + // Do NOT switch endpoints for 403 errors + log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - if len(content) > 0 { - // Use tiktoken for more accurate output token calculation + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", 
errClose) + } + return resp, err + } + + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + + // Fallback for usage if missing from upstream + if usageInfo.TotalTokens == 0 { if enc, encErr := getTokenizer(req.Model); encErr == nil { - if tokenCount, countErr := enc.Count(content); countErr == nil { - usageInfo.OutputTokens = int64(tokenCount) + if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { + usageInfo.InputTokens = inp } } - // Fallback to character count estimation if tiktoken fails - if usageInfo.OutputTokens == 0 { - usageInfo.OutputTokens = int64(len(content) / 4) + if len(content) > 0 { + // Use tiktoken for more accurate output token calculation + if enc, encErr := getTokenizer(req.Model); encErr == nil { + if tokenCount, countErr := enc.Count(content); countErr == nil { + usageInfo.OutputTokens = int64(tokenCount) + } + } + // Fallback to character count estimation if tiktoken fails if usageInfo.OutputTokens == 0 { - usageInfo.OutputTokens = 1 + usageInfo.OutputTokens = int64(len(content) / 4) + if usageInfo.OutputTokens == 0 { + usageInfo.OutputTokens = 1 + } } } + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens } - usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens - } - appendAPIResponseChunk(ctx, e.cfg, []byte(content)) - reporter.publish(ctx, usageInfo) + appendAPIResponseChunk(ctx, e.cfg, []byte(content)) + reporter.publish(ctx, usageInfo) - // Build response in Claude format for Kiro translator - // stopReason is extracted from upstream response by parseEventStream - kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) - out := sdktranslator.TranslateNonStream(ctx, to, from, 
req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) - resp = cliproxyexecutor.Response{Payload: []byte(out)} - return resp, nil + // Build response in Claude format for Kiro translator + // stopReason is extracted from upstream response by parseEventStream + kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil } // Inner retry loop exhausted for this endpoint, try next endpoint // Note: This code is unreachable because all paths in the inner loop @@ -559,6 +474,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } // All endpoints exhausted + if last429Err != nil { + return resp, last429Err + } return resp, fmt.Errorf("kiro: all endpoints exhausted") } @@ -569,14 +487,6 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut if accessToken == "" { return nil, fmt.Errorf("kiro: access token not found in auth") } - if profileArn == "" { - // Only warn if not using builder-id auth (which doesn't need profileArn) - if auth == nil || auth.Metadata == nil { - log.Debugf("kiro: profile ARN not found in auth (may be normal for builder-id)") - } else if authMethod, ok := auth.Metadata["auth_method"].(string); !ok || authMethod != "builder-id" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") - } - } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -599,30 +509,14 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) kiroModelID := e.mapModelToKiro(req.Model) - - // Check if this is an agentic model variant - isAgentic := 
strings.HasSuffix(req.Model, "-agentic") - - // Check if this is a chat-only model variant (no tool calling) - isChatOnly := strings.HasSuffix(req.Model, "-chat") - - // Determine initial origin - always use AI_EDITOR to match AIClient-2-API behavior - // AIClient-2-API uses AI_EDITOR for all models, which is the Kiro IDE quota - currentOrigin := "AI_EDITOR" - - // Determine if profileArn should be included based on auth method - // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) - effectiveProfileArn := profileArn - if auth != nil && auth.Metadata != nil { - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { - effectiveProfileArn = "" // Don't include profileArn for builder-id auth - } - } - - kiroPayload := e.buildKiroPayload(body, kiroModelID, effectiveProfileArn, currentOrigin, isAgentic, isChatOnly) + + // Determine agentic mode and effective profile ARN using helper functions + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) // Execute stream with retry on 401/403 and 429 (quota exhausted) - return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, kiroPayload, body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + // Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint + return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly) } // executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. 
@@ -633,233 +527,238 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) { maxRetries := 2 // Allow retries for token refresh + endpoint fallback endpointConfigs := getKiroEndpointConfigs(auth) + var last429Err error for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { endpointConfig := endpointConfigs[endpointIdx] url := endpointConfig.URL // Use this endpoint's compatible Origin (critical for avoiding 403 errors) currentOrigin = endpointConfig.Origin - + // Rebuild payload with the correct origin for this endpoint // Each endpoint requires its matching Origin value in the request body - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - + kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) - for attempt := 0; attempt <= maxRetries; attempt++ { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - httpReq.Header.Set("Accept", kiroAcceptStream) - // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) - httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", 
kiroFullUserAgent) - httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") - httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - - var attrs map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - - var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: kiroPayload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - httpResp, err := httpClient.Do(httpReq) - if err != nil { - recordAPIResponseError(ctx, e.cfg, err) - return nil, err - } - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - - // Handle 429 errors (quota exhausted) - try next endpoint - // Each endpoint has its own quota pool, so we can try different endpoints - if httpResp.StatusCode == 429 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint", endpointConfig.Name) - - // Break inner retry loop to try next endpoint (which has different quota) - break - } - - // Handle 5xx server errors with exponential backoff retry - if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - if attempt < maxRetries { - // Exponential backoff: 1s, 2s, 4s... 
(max 30s) - backoff := time.Duration(1< 30*time.Second { - backoff = 30 * time.Second - } - log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) - time.Sleep(backoff) - continue + for attempt := 0; attempt <= maxRetries; attempt++ { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return nil, err } - log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - // Handle 400 errors - Credential/Validation issues - // Do NOT switch endpoints - return error immediately - if httpResp.StatusCode == 400 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("Accept", kiroAcceptStream) + // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - // 400 errors indicate request validation issues - return immediately without retry - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, 
authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) - // Handle 401 errors with token refresh and retry - // 401 = Unauthorized (token expired/invalid) - refresh token - if httpResp.StatusCode == 401 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if attempt < maxRetries { - log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Errorf("kiro: token refresh failed: %v", refreshErr) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted + last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} - if refreshedAuth != nil { - auth = refreshedAuth - accessToken, profileArn = kiroCredentials(auth) - // Rebuild payload with new profile ARN if changed - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - log.Infof("kiro: token refreshed successfully, retrying stream request") + 
log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint, body: %s", + endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // Break inner retry loop to try next endpoint (which has different quota) + break + } + + // Handle 5xx server errors with exponential backoff retry + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + // Exponential backoff: 1s, 2s, 4s... (max 30s) + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) continue } + log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + // Handle 400 errors - Credential/Validation issues + // Do NOT switch endpoints - return error immediately + if httpResp.StatusCode == 400 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) - // Handle 402 errors - Monthly Limit Reached - if httpResp.StatusCode == 402 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) + log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - log.Warnf("kiro: stream received 402 (monthly limit). 
Upstream body: %s", string(respBody)) - - // Return upstream error body directly - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - // Handle 403 errors - Access Denied / Token Expired - // Do NOT switch endpoints for 403 errors - if httpResp.StatusCode == 403 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - // Log the 403 error details for debugging - log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) - - respBodyStr := string(respBody) - - // Check for SUSPENDED status - return immediately without retry - if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { - log.Errorf("kiro: account is suspended, cannot proceed") - return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} + // 400 errors indicate request validation issues - return immediately without retry + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) - isTokenRelated := strings.Contains(respBodyStr, "token") || - strings.Contains(respBodyStr, "expired") || - strings.Contains(respBodyStr, "invalid") || - strings.Contains(respBodyStr, "unauthorized") + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) - if isTokenRelated && attempt < maxRetries { - log.Warnf("kiro: 403 appears token-related, attempting token refresh") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Errorf("kiro: token refresh failed: %v", refreshErr) - // Token refresh failed - return error immediately - return nil, 
statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - if refreshedAuth != nil { - auth = refreshedAuth - accessToken, profileArn = kiroCredentials(auth) - kiroPayload = e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) - log.Infof("kiro: token refreshed for 403, retrying stream request") - continue + if attempt < maxRetries { + log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) + + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed successfully, retrying stream request") + continue + } } + + log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} } - // For non-token 403 or after max retries, return error immediately + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: stream received 402 (monthly limit). 
Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired // Do NOT switch endpoints for 403 errors - log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} - } + // Log the 403 error details for debugging + log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) - out := make(chan cliproxyexecutor.StreamChunk) + respBodyStr := string(respBody) - go func(resp *http.Response) { - defer close(out) - defer func() { - if r := recover(); r != nil { - log.Errorf("kiro: panic in stream handler: %v", r) - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + log.Errorf("kiro: account is suspended, cannot proceed") + return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} } - }() - defer func() { - if errClose := resp.Body.Close(); errClose != nil { + + // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := 
strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + log.Infof("kiro: token refreshed for 403, retrying stream request") + continue + } + } + + // For non-token 403 or after max retries, return error immediately + // Do NOT switch endpoints for 403 errors + log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) + if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } - }() + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} + } - e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter) - }(httpResp) + out := make(chan cliproxyexecutor.StreamChunk) - return out, nil + go func(resp *http.Response) { + defer close(out) + defer func() { + if r := recover(); r != nil { + log.Errorf("kiro: panic in stream handler: %v", r) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} + } + }() + defer func() { 
+ if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter) + }(httpResp) + + return out, nil } // Inner retry loop exhausted for this endpoint, try next endpoint // Note: This code is unreachable because all paths in the inner loop @@ -867,16 +766,18 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } // All endpoints exhausted + if last429Err != nil { + return nil, last429Err + } return nil, fmt.Errorf("kiro: stream all endpoints exhausted") } - // kiroCredentials extracts access token and profile ARN from auth. func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { if auth == nil { return "", "" } - + // Try Metadata first (wrapper format) if auth.Metadata != nil { if token, ok := auth.Metadata["access_token"].(string); ok { @@ -886,13 +787,13 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { profileArn = arn } } - + // Try Attributes if accessToken == "" && auth.Attributes != nil { accessToken = auth.Attributes["access_token"] profileArn = auth.Attributes["profile_arn"] } - + // Try direct fields from flat JSON format (new AWS Builder ID format) if accessToken == "" && auth.Metadata != nil { if token, ok := auth.Metadata["accessToken"].(string); ok { @@ -902,10 +803,46 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { profileArn = arn } } - + return accessToken, profileArn } +// determineAgenticMode determines if the model is an agentic or chat-only variant. +// Returns (isAgentic, isChatOnly) based on model name suffixes. 
+func determineAgenticMode(model string) (isAgentic, isChatOnly bool) { + isAgentic = strings.HasSuffix(model, "-agentic") + isChatOnly = strings.HasSuffix(model, "-chat") + return isAgentic, isChatOnly +} + +// getEffectiveProfileArn determines if profileArn should be included based on auth method. +// profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO). +func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + return "" // Don't include profileArn for builder-id auth + } + } + return profileArn +} + +// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method, +// and logs a warning if profileArn is missing for non-builder-id auth. +// This consolidates the auth_method check that was previously done separately. +func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + // builder-id auth doesn't need profileArn + return "" + } + } + // For non-builder-id auth (social auth), profileArn is required + if profileArn == "" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } + return profileArn +} + // mapModelToKiro maps external model names to Kiro model IDs. // Supports both Kiro and Amazon Q prefixes since they use the same API. // Agentic variants (-agentic suffix) map to the same backend model IDs. 
@@ -939,28 +876,28 @@ func (e *KiroExecutor) mapModelToKiro(model string) string { "claude-sonnet-4-20250514": "claude-sonnet-4", "auto": "auto", // Agentic variants (same backend model IDs, but with special system prompt) - "claude-opus-4.5-agentic": "claude-opus-4.5", - "claude-sonnet-4.5-agentic": "claude-sonnet-4.5", - "claude-sonnet-4-agentic": "claude-sonnet-4", - "claude-haiku-4.5-agentic": "claude-haiku-4.5", - "kiro-claude-opus-4-5-agentic": "claude-opus-4.5", - "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", - "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", - "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", + "claude-opus-4.5-agentic": "claude-opus-4.5", + "claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + "claude-sonnet-4-agentic": "claude-sonnet-4", + "claude-haiku-4.5-agentic": "claude-haiku-4.5", + "kiro-claude-opus-4-5-agentic": "claude-opus-4.5", + "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", + "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", } if kiroID, ok := modelMap[model]; ok { return kiroID } - + // Smart fallback: try to infer model type from name patterns modelLower := strings.ToLower(model) - + // Check for Haiku variants if strings.Contains(modelLower, "haiku") { log.Debugf("kiro: unknown Haiku model '%s', mapping to claude-haiku-4.5", model) return "claude-haiku-4.5" } - + // Check for Sonnet variants if strings.Contains(modelLower, "sonnet") { // Check for specific version patterns @@ -976,13 +913,13 @@ func (e *KiroExecutor) mapModelToKiro(model string) string { log.Debugf("kiro: unknown Sonnet model '%s', mapping to claude-sonnet-4", model) return "claude-sonnet-4" } - + // Check for Opus variants if strings.Contains(modelLower, "opus") { log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model) return "claude-opus-4.5" } - + // Final fallback to Sonnet 4.5 (most commonly used model) log.Warnf("kiro: unknown model '%s', falling back 
to claude-sonnet-4.5", model) return "claude-sonnet-4.5" @@ -1008,582 +945,23 @@ type eventStreamMessage struct { Payload []byte // JSON payload of the message } -// Kiro API request structs - field order determines JSON key order - -type kiroPayload struct { - ConversationState kiroConversationState `json:"conversationState"` - ProfileArn string `json:"profileArn,omitempty"` - InferenceConfig *kiroInferenceConfig `json:"inferenceConfig,omitempty"` -} - -// kiroInferenceConfig contains inference parameters for the Kiro API. -// NOTE: This is an experimental addition - Kiro/Amazon Q API may not support these parameters. -// If the API ignores or rejects these fields, response length is controlled internally by the model. -type kiroInferenceConfig struct { - MaxTokens int `json:"maxTokens,omitempty"` // Maximum output tokens (may be ignored by API) - Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (may be ignored by API) -} - -type kiroConversationState struct { - ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field - ConversationID string `json:"conversationId"` - CurrentMessage kiroCurrentMessage `json:"currentMessage"` - History []kiroHistoryMessage `json:"history,omitempty"` -} - -type kiroCurrentMessage struct { - UserInputMessage kiroUserInputMessage `json:"userInputMessage"` -} - -type kiroHistoryMessage struct { - UserInputMessage *kiroUserInputMessage `json:"userInputMessage,omitempty"` - AssistantResponseMessage *kiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` -} - -// kiroImage represents an image in Kiro API format -type kiroImage struct { - Format string `json:"format"` - Source kiroImageSource `json:"source"` -} - -// kiroImageSource contains the image data -type kiroImageSource struct { - Bytes string `json:"bytes"` // base64 encoded image data -} - -type kiroUserInputMessage struct { - Content string `json:"content"` - ModelID string `json:"modelId"` - Origin 
string `json:"origin"` - Images []kiroImage `json:"images,omitempty"` - UserInputMessageContext *kiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` -} - -type kiroUserInputMessageContext struct { - ToolResults []kiroToolResult `json:"toolResults,omitempty"` - Tools []kiroToolWrapper `json:"tools,omitempty"` -} - -type kiroToolResult struct { - Content []kiroTextContent `json:"content"` - Status string `json:"status"` - ToolUseID string `json:"toolUseId"` -} - -type kiroTextContent struct { - Text string `json:"text"` -} - -type kiroToolWrapper struct { - ToolSpecification kiroToolSpecification `json:"toolSpecification"` -} - -type kiroToolSpecification struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema kiroInputSchema `json:"inputSchema"` -} - -type kiroInputSchema struct { - JSON interface{} `json:"json"` -} - -type kiroAssistantResponseMessage struct { - Content string `json:"content"` - ToolUses []kiroToolUse `json:"toolUses,omitempty"` -} - -type kiroToolUse struct { - ToolUseID string `json:"toolUseId"` - Name string `json:"name"` - Input map[string]interface{} `json:"input"` -} - -// buildKiroPayload constructs the Kiro API request payload. -// Supports tool calling - tools are passed via userInputMessageContext. -// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. -// isAgentic parameter enables chunked write optimization prompt for -agentic model variants. -// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). -// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. -// -// max_tokens support: Kiro/Amazon Q API may not officially support max_tokens parameter. -// We attempt to pass it via inferenceConfig.maxTokens, but the API may ignore it. -// Response truncation can be detected via stop_reason == "max_tokens" in the response. 
-func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { - // Extract max_tokens for potential use in inferenceConfig - var maxTokens int64 - if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() { - maxTokens = mt.Int() - } - - // Extract temperature if specified - var temperature float64 - var hasTemperature bool - if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() { - temperature = temp.Float() - hasTemperature = true - } - - // Normalize origin value for Kiro API compatibility - // Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values - switch origin { - case "KIRO_CLI": - origin = "CLI" - case "KIRO_AI_EDITOR": - origin = "AI_EDITOR" - case "AMAZON_Q": - origin = "CLI" - case "KIRO_IDE": - origin = "AI_EDITOR" - // Add any other non-standard origin values that need normalization - default: - // Keep the original value if it's already standard - // Valid values: "CLI", "AI_EDITOR" - } - log.Debugf("kiro: normalized origin value: %s", origin) - - messages := gjson.GetBytes(claudeBody, "messages") - - // For chat-only mode, don't include tools - var tools gjson.Result - if !isChatOnly { - tools = gjson.GetBytes(claudeBody, "tools") - } - - // Extract system prompt - can be string or array of content blocks - systemField := gjson.GetBytes(claudeBody, "system") - var systemPrompt string - if systemField.IsArray() { - // System is array of content blocks, extract text - var sb strings.Builder - for _, block := range systemField.Array() { - if block.Get("type").String() == "text" { - sb.WriteString(block.Get("text").String()) - } else if block.Type == gjson.String { - sb.WriteString(block.String()) - } - } - systemPrompt = sb.String() - } else { - systemPrompt = systemField.String() - } - - // Check for thinking parameter in Claude API request - // Claude API format: {"thinking": {"type": "enabled", "budget_tokens": 16000}} - // When thinking is enabled, 
inject dynamic thinkingHint based on budget_tokens - // This allows reasoning_effort (low/medium/high) to control actual thinking length - thinkingEnabled := false - var budgetTokens int64 = 16000 // Default value (same as OpenAI reasoning_effort "medium") - thinkingField := gjson.GetBytes(claudeBody, "thinking") - if thinkingField.Exists() { - // Check if thinking.type is "enabled" - thinkingType := thinkingField.Get("type").String() - if thinkingType == "enabled" { - thinkingEnabled = true - // Read budget_tokens if specified - this value comes from: - // - Claude API: thinking.budget_tokens directly - // - OpenAI API: reasoning_effort -> budget_tokens (low:4000, medium:16000, high:32000) - if bt := thinkingField.Get("budget_tokens"); bt.Exists() { - budgetTokens = bt.Int() - // If budget_tokens <= 0, disable thinking explicitly - // This allows users to disable thinking by setting budget_tokens to 0 - if budgetTokens <= 0 { - thinkingEnabled = false - log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0") - } - } - if thinkingEnabled { - log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) - } - } - } - - // Inject timestamp context for better temporal awareness - // Based on amq2api implementation - helps model understand current time context - timestamp := time.Now().Format("2006-01-02 15:04:05 MST") - timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp) - if systemPrompt != "" { - systemPrompt = timestampContext + "\n\n" + systemPrompt - } else { - systemPrompt = timestampContext - } - log.Debugf("kiro: injected timestamp context: %s", timestamp) - - // Inject agentic optimization prompt for -agentic model variants - // This prevents AWS Kiro API timeouts during large file write operations - if isAgentic { - if systemPrompt != "" { - systemPrompt += "\n" - } - systemPrompt += kiroAgenticSystemPrompt - } - - // Inject thinking hint when thinking mode is enabled - // This tells the 
model to use tags in its response - // DYNAMICALLY set max_thinking_length based on budget_tokens from request - // This respects the reasoning_effort setting: low(4000), medium(16000), high(32000) - if thinkingEnabled { - if systemPrompt != "" { - systemPrompt += "\n" - } - // Build dynamic thinking hint with the actual budget_tokens value - dynamicThinkingHint := fmt.Sprintf("interleaved%d", budgetTokens) - systemPrompt += dynamicThinkingHint - log.Debugf("kiro: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens) - } - - // Convert Claude tools to Kiro format - var kiroTools []kiroToolWrapper - if tools.IsArray() { - for _, tool := range tools.Array() { - name := tool.Get("name").String() - description := tool.Get("description").String() - inputSchema := tool.Get("input_schema").Value() - - // Truncate long descriptions (Kiro API limit is in bytes) - // Truncate at valid UTF-8 boundary to avoid breaking multi-byte chars - // Add truncation notice to help model understand the description is incomplete - if len(description) > kiroMaxToolDescLen { - // Find a valid UTF-8 boundary before the limit - // Reserve space for truncation notice (about 30 bytes) - truncLen := kiroMaxToolDescLen - 30 - for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { - truncLen-- - } - description = description[:truncLen] + "... 
(description truncated)" - } - - kiroTools = append(kiroTools, kiroToolWrapper{ - ToolSpecification: kiroToolSpecification{ - Name: name, - Description: description, - InputSchema: kiroInputSchema{JSON: inputSchema}, - }, - }) - } - } - - var history []kiroHistoryMessage - var currentUserMsg *kiroUserInputMessage - var currentToolResults []kiroToolResult - - // Merge adjacent messages with the same role before processing - // This reduces API call complexity and improves compatibility - messagesArray := mergeAdjacentMessages(messages.Array()) - for i, msg := range messagesArray { - role := msg.Get("role").String() - isLastMessage := i == len(messagesArray)-1 - - if role == "user" { - userMsg, toolResults := e.buildUserMessageStruct(msg, modelID, origin) - if isLastMessage { - currentUserMsg = &userMsg - currentToolResults = toolResults - } else { - // CRITICAL: Kiro API requires content to be non-empty for history messages too - if strings.TrimSpace(userMsg.Content) == "" { - if len(toolResults) > 0 { - userMsg.Content = "Tool results provided." 
- } else { - userMsg.Content = "Continue" - } - } - // For history messages, embed tool results in context - if len(toolResults) > 0 { - userMsg.UserInputMessageContext = &kiroUserInputMessageContext{ - ToolResults: toolResults, - } - } - history = append(history, kiroHistoryMessage{ - UserInputMessage: &userMsg, - }) - } - } else if role == "assistant" { - assistantMsg := e.buildAssistantMessageStruct(msg) - // If this is the last message and it's an assistant message, - // we need to add it to history and create a "Continue" user message - // because Kiro API requires currentMessage to be userInputMessage type - if isLastMessage { - history = append(history, kiroHistoryMessage{ - AssistantResponseMessage: &assistantMsg, - }) - // Create a "Continue" user message as currentMessage - currentUserMsg = &kiroUserInputMessage{ - Content: "Continue", - ModelID: modelID, - Origin: origin, - } - } else { - history = append(history, kiroHistoryMessage{ - AssistantResponseMessage: &assistantMsg, - }) - } - } - } - - // Build content with system prompt - if currentUserMsg != nil { - var contentBuilder strings.Builder - - // Add system prompt if present - if systemPrompt != "" { - contentBuilder.WriteString("--- SYSTEM PROMPT ---\n") - contentBuilder.WriteString(systemPrompt) - contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n") - } - - // Add the actual user message - contentBuilder.WriteString(currentUserMsg.Content) - finalContent := contentBuilder.String() - - // CRITICAL: Kiro API requires content to be non-empty, even when toolResults are present - // If content is empty or only whitespace, provide a default message - if strings.TrimSpace(finalContent) == "" { - if len(currentToolResults) > 0 { - finalContent = "Tool results provided." 
- } else { - finalContent = "Continue" - } - log.Debugf("kiro: content was empty, using default: %s", finalContent) - } - currentUserMsg.Content = finalContent - - // Deduplicate currentToolResults before adding to context - // Kiro API does not accept duplicate toolUseIds - if len(currentToolResults) > 0 { - seenIDs := make(map[string]bool) - uniqueToolResults := make([]kiroToolResult, 0, len(currentToolResults)) - for _, tr := range currentToolResults { - if !seenIDs[tr.ToolUseID] { - seenIDs[tr.ToolUseID] = true - uniqueToolResults = append(uniqueToolResults, tr) - } else { - log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID) - } - } - currentToolResults = uniqueToolResults - } - - // Build userInputMessageContext with tools and tool results - if len(kiroTools) > 0 || len(currentToolResults) > 0 { - currentUserMsg.UserInputMessageContext = &kiroUserInputMessageContext{ - Tools: kiroTools, - ToolResults: currentToolResults, - } - } - } - - // Build payload using structs (preserves key order) - var currentMessage kiroCurrentMessage - if currentUserMsg != nil { - currentMessage = kiroCurrentMessage{UserInputMessage: *currentUserMsg} - } else { - // Fallback when no user messages - still include system prompt if present - fallbackContent := "" - if systemPrompt != "" { - fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n" - } - currentMessage = kiroCurrentMessage{UserInputMessage: kiroUserInputMessage{ - Content: fallbackContent, - ModelID: modelID, - Origin: origin, - }} - } - - // Build inferenceConfig if we have any inference parameters - var inferenceConfig *kiroInferenceConfig - if maxTokens > 0 || hasTemperature { - inferenceConfig = &kiroInferenceConfig{} - if maxTokens > 0 { - inferenceConfig.MaxTokens = int(maxTokens) - } - if hasTemperature { - inferenceConfig.Temperature = temperature - } - } - - // Build payload with correct field order (matches struct definition) - payload := 
kiroPayload{ - ConversationState: kiroConversationState{ - ChatTriggerType: "MANUAL", // Required by Kiro API - must be first - ConversationID: uuid.New().String(), - CurrentMessage: currentMessage, - History: history, // Now always included (non-nil slice) - }, - ProfileArn: profileArn, - InferenceConfig: inferenceConfig, - } - - result, err := json.Marshal(payload) - if err != nil { - log.Debugf("kiro: failed to marshal payload: %v", err) - return nil - } - - return result -} - -// buildUserMessageStruct builds a user message and extracts tool results -// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. -// IMPORTANT: Kiro API does not accept duplicate toolUseIds, so we deduplicate here. -func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin string) (kiroUserInputMessage, []kiroToolResult) { - content := msg.Get("content") - var contentBuilder strings.Builder - var toolResults []kiroToolResult - var images []kiroImage - - // Track seen toolUseIds to deduplicate - Kiro API rejects duplicate toolUseIds - seenToolUseIDs := make(map[string]bool) - - if content.IsArray() { - for _, part := range content.Array() { - partType := part.Get("type").String() - switch partType { - case "text": - contentBuilder.WriteString(part.Get("text").String()) - case "image": - // Extract image data from Claude API format - mediaType := part.Get("source.media_type").String() - data := part.Get("source.data").String() - - // Extract format from media_type (e.g., "image/png" -> "png") - format := "" - if idx := strings.LastIndex(mediaType, "/"); idx != -1 { - format = mediaType[idx+1:] - } - - if format != "" && data != "" { - images = append(images, kiroImage{ - Format: format, - Source: kiroImageSource{ - Bytes: data, - }, - }) - } - case "tool_result": - // Extract tool result for API - toolUseID := part.Get("tool_use_id").String() - - // Skip duplicate toolUseIds - Kiro API does not accept duplicates - if 
seenToolUseIDs[toolUseID] { - log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID) - continue - } - seenToolUseIDs[toolUseID] = true - - isError := part.Get("is_error").Bool() - resultContent := part.Get("content") - - // Convert content to Kiro format: [{text: "..."}] - var textContents []kiroTextContent - if resultContent.IsArray() { - for _, item := range resultContent.Array() { - if item.Get("type").String() == "text" { - textContents = append(textContents, kiroTextContent{Text: item.Get("text").String()}) - } else if item.Type == gjson.String { - textContents = append(textContents, kiroTextContent{Text: item.String()}) - } - } - } else if resultContent.Type == gjson.String { - textContents = append(textContents, kiroTextContent{Text: resultContent.String()}) - } - - // If no content, add default message - if len(textContents) == 0 { - textContents = append(textContents, kiroTextContent{Text: "Tool use was cancelled by the user"}) - } - - status := "success" - if isError { - status = "error" - } - - toolResults = append(toolResults, kiroToolResult{ - ToolUseID: toolUseID, - Content: textContents, - Status: status, - }) - } - } - } else { - contentBuilder.WriteString(content.String()) - } - - userMsg := kiroUserInputMessage{ - Content: contentBuilder.String(), - ModelID: modelID, - Origin: origin, - } - - // Add images to message if present - if len(images) > 0 { - userMsg.Images = images - } - - return userMsg, toolResults -} - -// buildAssistantMessageStruct builds an assistant message with tool uses -func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssistantResponseMessage { - content := msg.Get("content") - var contentBuilder strings.Builder - var toolUses []kiroToolUse - - if content.IsArray() { - for _, part := range content.Array() { - partType := part.Get("type").String() - switch partType { - case "text": - contentBuilder.WriteString(part.Get("text").String()) - case "tool_use": - // Extract tool use 
for API - toolUseID := part.Get("id").String() - toolName := part.Get("name").String() - toolInput := part.Get("input") - - // Convert input to map - var inputMap map[string]interface{} - if toolInput.IsObject() { - inputMap = make(map[string]interface{}) - toolInput.ForEach(func(key, value gjson.Result) bool { - inputMap[key.String()] = value.Value() - return true - }) - } - - toolUses = append(toolUses, kiroToolUse{ - ToolUseID: toolUseID, - Name: toolName, - Input: inputMap, - }) - } - } - } else { - contentBuilder.WriteString(content.String()) - } - - return kiroAssistantResponseMessage{ - Content: contentBuilder.String(), - ToolUses: toolUses, - } -} - -// NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults +// NOTE: Request building functions moved to internal/translator/kiro/claude/kiro_claude_request.go +// The executor now uses kiroclaude.BuildKiroPayload() instead // parseEventStream parses AWS Event Stream binary format. // Extracts text content, tool uses, and stop_reason from the response. // Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. 
// Returns: content, toolUses, usageInfo, stopReason, error -func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, string, error) { +func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.KiroToolUse, usage.Detail, string, error) { var content strings.Builder - var toolUses []kiroToolUse + var toolUses []kiroclaude.KiroToolUse var usageInfo usage.Detail var stopReason string // Extracted from upstream response reader := bufio.NewReader(body) // Tool use state tracking for input buffering and deduplication processedIDs := make(map[string]bool) - var currentToolUse *toolUseState + var currentToolUse *kiroclaude.ToolUseState for { msg, eventErr := e.readEventStreamMessage(reader) @@ -1635,11 +1013,11 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, // Extract stop_reason from various event formats // Kiro/Amazon Q API may include stop_reason in different locations - if sr := getString(event, "stop_reason"); sr != "" { + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { stopReason = sr log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason) } - if sr := getString(event, "stopReason"); sr != "" { + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { stopReason = sr log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason) } @@ -1657,11 +1035,11 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, content.WriteString(contentText) } // Extract stop_reason from assistantResponseEvent - if sr := getString(assistantResp, "stop_reason"); sr != "" { + if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { stopReason = sr log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason) } - if sr := getString(assistantResp, "stopReason"); sr != "" { + if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { stopReason 
= sr log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason) } @@ -1669,17 +1047,17 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok { for _, tuRaw := range toolUsesRaw { if tu, ok := tuRaw.(map[string]interface{}); ok { - toolUseID := getString(tu, "toolUseId") + toolUseID := kirocommon.GetStringValue(tu, "toolUseId") // Check for duplicate if processedIDs[toolUseID] { log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID) continue } processedIDs[toolUseID] = true - - toolUse := kiroToolUse{ + + toolUse := kiroclaude.KiroToolUse{ ToolUseID: toolUseID, - Name: getString(tu, "name"), + Name: kirocommon.GetStringValue(tu, "name"), } if input, ok := tu["input"].(map[string]interface{}); ok { toolUse.Input = input @@ -1697,17 +1075,17 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok { for _, tuRaw := range toolUsesRaw { if tu, ok := tuRaw.(map[string]interface{}); ok { - toolUseID := getString(tu, "toolUseId") + toolUseID := kirocommon.GetStringValue(tu, "toolUseId") // Check for duplicate if processedIDs[toolUseID] { log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID) continue } processedIDs[toolUseID] = true - - toolUse := kiroToolUse{ + + toolUse := kiroclaude.KiroToolUse{ ToolUseID: toolUseID, - Name: getString(tu, "name"), + Name: kirocommon.GetStringValue(tu, "name"), } if input, ok := tu["input"].(map[string]interface{}); ok { toolUse.Input = input @@ -1719,7 +1097,7 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, case "toolUseEvent": // Handle dedicated tool use events with input buffering - completedToolUses, newState := e.processToolUseEvent(event, currentToolUse, processedIDs) + completedToolUses, newState := 
kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) currentToolUse = newState toolUses = append(toolUses, completedToolUses...) @@ -1733,11 +1111,11 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, case "messageStopEvent", "message_stop": // Handle message stop events which may contain stop_reason - if sr := getString(event, "stop_reason"); sr != "" { + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { stopReason = sr log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason) } - if sr := getString(event, "stopReason"); sr != "" { + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { stopReason = sr log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason) } @@ -1756,11 +1134,11 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, // Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}]) contentStr := content.String() - cleanedContent, embeddedToolUses := e.parseEmbeddedToolCalls(contentStr, processedIDs) + cleanedContent, embeddedToolUses := kiroclaude.ParseEmbeddedToolCalls(contentStr, processedIDs) toolUses = append(toolUses, embeddedToolUses...) 
// Deduplicate all tool uses - toolUses = deduplicateToolUses(toolUses) + toolUses = kiroclaude.DeduplicateToolUses(toolUses) // Apply fallback logic for stop_reason if not provided by upstream // Priority: upstream stopReason > tool_use detection > end_turn default @@ -1876,13 +1254,59 @@ func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStrea }, nil } +func skipEventStreamHeaderValue(headers []byte, offset int, valueType byte) (int, bool) { + switch valueType { + case 0, 1: // bool true / bool false + return offset, true + case 2: // byte + if offset+1 > len(headers) { + return offset, false + } + return offset + 1, true + case 3: // short + if offset+2 > len(headers) { + return offset, false + } + return offset + 2, true + case 4: // int + if offset+4 > len(headers) { + return offset, false + } + return offset + 4, true + case 5: // long + if offset+8 > len(headers) { + return offset, false + } + return offset + 8, true + case 6: // byte array (2-byte length + data) + if offset+2 > len(headers) { + return offset, false + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + return offset, false + } + return offset + valueLen, true + case 8: // timestamp + if offset+8 > len(headers) { + return offset, false + } + return offset + 8, true + case 9: // uuid + if offset+16 > len(headers) { + return offset, false + } + return offset + 16, true + default: + return offset, false + } +} + // extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix) func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { offset := 0 for offset < len(headers) { - if offset >= len(headers) { - break - } nameLen := int(headers[offset]) offset++ if offset+nameLen > len(headers) { @@ -1912,240 +1336,21 @@ func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { if name == ":event-type" { return value } - } else { - // 
Skip other types + continue + } + + nextOffset, ok := skipEventStreamHeaderValue(headers, offset, valueType) + if !ok { break } + offset = nextOffset } return "" } -// extractEventType extracts the event type from AWS Event Stream headers -// Note: This is the legacy version that expects headerBytes to include prelude CRC prefix -func (e *KiroExecutor) extractEventType(headerBytes []byte) string { - // Skip prelude CRC (4 bytes) - if len(headerBytes) < 4 { - return "" - } - headers := headerBytes[4:] - - offset := 0 - for offset < len(headers) { - if offset >= len(headers) { - break - } - nameLen := int(headers[offset]) - offset++ - if offset+nameLen > len(headers) { - break - } - name := string(headers[offset : offset+nameLen]) - offset += nameLen - - if offset >= len(headers) { - break - } - valueType := headers[offset] - offset++ - - if valueType == 7 { // String type - if offset+2 > len(headers) { - break - } - valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) - offset += 2 - if offset+valueLen > len(headers) { - break - } - value := string(headers[offset : offset+valueLen]) - offset += valueLen - - if name == ":event-type" { - return value - } - } else { - // Skip other types - break - } - } - return "" -} - -// getString safely extracts a string from a map -func getString(m map[string]interface{}, key string) string { - if v, ok := m[key].(string); ok { - return v - } - return "" -} - -// buildClaudeResponse constructs a Claude-compatible response. -// Supports tool_use blocks when tools are present in the response. -// Supports thinking blocks - parses tags and converts to Claude thinking content blocks. -// stopReason is passed from upstream; fallback logic applied if empty. 
-func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { - var contentBlocks []map[string]interface{} - - // Extract thinking blocks and text from content - // This handles ... tags from Kiro's response - if content != "" { - blocks := e.extractThinkingFromContent(content) - contentBlocks = append(contentBlocks, blocks...) - - // DIAGNOSTIC: Log if thinking blocks were extracted - for _, block := range blocks { - if block["type"] == "thinking" { - thinkingContent := block["thinking"].(string) - log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent)) - } - } - } - - // Add tool_use blocks - for _, toolUse := range toolUses { - contentBlocks = append(contentBlocks, map[string]interface{}{ - "type": "tool_use", - "id": toolUse.ToolUseID, - "name": toolUse.Name, - "input": toolUse.Input, - }) - } - - // Ensure at least one content block (Claude API requires non-empty content) - if len(contentBlocks) == 0 { - contentBlocks = append(contentBlocks, map[string]interface{}{ - "type": "text", - "text": "", - }) - } - - // Use upstream stopReason; apply fallback logic if not provided - if stopReason == "" { - stopReason = "end_turn" - if len(toolUses) > 0 { - stopReason = "tool_use" - } - log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason) - } - - // Log warning if response was truncated due to max_tokens - if stopReason == "max_tokens" { - log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)") - } - - response := map[string]interface{}{ - "id": "msg_" + uuid.New().String()[:24], - "type": "message", - "role": "assistant", - "model": model, - "content": contentBlocks, - "stop_reason": stopReason, - "usage": map[string]interface{}{ - "input_tokens": usageInfo.InputTokens, - "output_tokens": usageInfo.OutputTokens, - }, - } - result, _ := json.Marshal(response) - return result -} - -// 
extractThinkingFromContent parses content to extract thinking blocks and text. -// Returns a list of content blocks in the order they appear in the content. -// Handles interleaved thinking and text blocks correctly. -// Based on the streaming implementation's thinking tag handling. -func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]interface{} { - var blocks []map[string]interface{} - - if content == "" { - return blocks - } - - // Check if content contains thinking tags at all - if !strings.Contains(content, thinkingStartTag) { - // No thinking tags, return as plain text - return []map[string]interface{}{ - { - "type": "text", - "text": content, - }, - } - } - - log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content)) - - remaining := content - - for len(remaining) > 0 { - // Look for tag - startIdx := strings.Index(remaining, thinkingStartTag) - - if startIdx == -1 { - // No more thinking tags, add remaining as text - if strings.TrimSpace(remaining) != "" { - blocks = append(blocks, map[string]interface{}{ - "type": "text", - "text": remaining, - }) - } - break - } - - // Add text before thinking tag (if any meaningful content) - if startIdx > 0 { - textBefore := remaining[:startIdx] - if strings.TrimSpace(textBefore) != "" { - blocks = append(blocks, map[string]interface{}{ - "type": "text", - "text": textBefore, - }) - } - } - - // Move past the opening tag - remaining = remaining[startIdx+len(thinkingStartTag):] - - // Find closing tag - endIdx := strings.Index(remaining, thinkingEndTag) - - if endIdx == -1 { - // No closing tag found, treat rest as thinking content (incomplete response) - if strings.TrimSpace(remaining) != "" { - blocks = append(blocks, map[string]interface{}{ - "type": "thinking", - "thinking": remaining, - }) - log.Warnf("kiro: extractThinkingFromContent - missing closing tag") - } - break - } - - // Extract thinking content between tags - thinkContent := 
remaining[:endIdx] - if strings.TrimSpace(thinkContent) != "" { - blocks = append(blocks, map[string]interface{}{ - "type": "thinking", - "thinking": thinkContent, - }) - log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent)) - } - - // Move past the closing tag - remaining = remaining[endIdx+len(thinkingEndTag):] - } - - // If no blocks were created (all whitespace), return empty text block - if len(blocks) == 0 { - blocks = append(blocks, map[string]interface{}{ - "type": "text", - "text": "", - }) - } - - return blocks -} - -// NOTE: Tool uses are now extracted from API response, not parsed from text +// NOTE: Response building functions moved to internal/translator/kiro/claude/kiro_claude_response.go +// The executor now uses kiroclaude.BuildClaudeResponse() and kiroclaude.ExtractThinkingFromContent() instead // streamToChannel converts AWS Event Stream to channel-based streaming. // Supports tool calling - emits tool_use content blocks when tools are used. 
@@ -2155,12 +1360,12 @@ func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]i func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers var totalUsage usage.Detail - var hasToolUses bool // Track if any tool uses were emitted - var upstreamStopReason string // Track stop_reason from upstream events + var hasToolUses bool // Track if any tool uses were emitted + var upstreamStopReason string // Track stop_reason from upstream events // Tool use state tracking for input buffering and deduplication processedIDs := make(map[string]bool) - var currentToolUse *toolUseState + var currentToolUse *kiroclaude.ToolUseState // NOTE: Duplicate content filtering removed - it was causing legitimate repeated // content (like consecutive newlines) to be incorrectly filtered out. 
@@ -2185,17 +1390,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Thinking mode state tracking - based on amq2api implementation // Tracks whether we're inside a block and handles partial tags inThinkBlock := false - pendingStartTagChars := 0 // Number of chars that might be start of - pendingEndTagChars := 0 // Number of chars that might be start of - isThinkingBlockOpen := false // Track if thinking content block is open - thinkingBlockIndex := -1 // Index of the thinking content block + pendingStartTagChars := 0 // Number of chars that might be start of + pendingEndTagChars := 0 // Number of chars that might be start of + isThinkingBlockOpen := false // Track if thinking content block is open + thinkingBlockIndex := -1 // Index of the thinking content block // Pre-calculate input tokens from request if possible // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback if enc, err := getTokenizer(model); err == nil { var inputTokens int64 var countMethod string - + // Try Claude format first (Kiro uses Claude API format) if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 { inputTokens = inp @@ -2212,7 +1417,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } countMethod = "estimate" } - + totalUsage.InputTokens = inputTokens log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) @@ -2239,7 +1444,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if eventErr != nil { // Log the error log.Errorf("kiro: streamToChannel error: %v", eventErr) - + // Send error to channel for client notification out <- cliproxyexecutor.StreamChunk{Err: eventErr} return @@ -2247,71 +1452,71 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if msg 
== nil { // Normal end of stream (EOF) // Flush any incomplete tool use before ending stream - if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] { - log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) - fullInput := currentToolUse.inputBuffer.String() - repairedJSON := repairJSON(fullInput) + if currentToolUse != nil && !processedIDs[currentToolUse.ToolUseID] { + log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) + fullInput := currentToolUse.InputBuffer.String() + repairedJSON := kiroclaude.RepairJSON(fullInput) var finalInput map[string]interface{} if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) finalInput = make(map[string]interface{}) } - - processedIDs[currentToolUse.toolUseID] = true + + processedIDs[currentToolUse.ToolUseID] = true contentBlockIndex++ - + // Send tool_use content block - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.toolUseID, currentToolUse.name) + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} } } - + // Send tool input as delta inputBytes, _ := json.Marshal(finalInput) - inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, 
&translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} } } - + // Close block - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} } } - + hasToolUses = true currentToolUse = nil } - + // Flush any pending tag characters at EOF // These are partial tag prefixes that were held back waiting for more data // Since no more data is coming, output them as regular text var pendingText string if pendingStartTagChars > 0 { - pendingText = thinkingStartTag[:pendingStartTagChars] + pendingText = kirocommon.ThinkingStartTag[:pendingStartTagChars] log.Debugf("kiro: flushing pending start tag chars at EOF: %q", pendingText) pendingStartTagChars = 0 } if pendingEndTagChars > 0 { - pendingText += thinkingEndTag[:pendingEndTagChars] + pendingText += kirocommon.ThinkingEndTag[:pendingEndTagChars] log.Debugf("kiro: flushing pending end tag chars at EOF: %q", pendingText) pendingEndTagChars = 0 } - + // Output pending text if any if pendingText != "" { // If we're in a thinking block, output as thinking content if inThinkBlock && isThinkingBlockOpen { - thinkingEvent := e.buildClaudeThinkingDeltaEvent(pendingText, thinkingBlockIndex) + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(pendingText, thinkingBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2323,7 +1528,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if 
!isTextBlockOpen { contentBlockIndex++ isTextBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2331,8 +1536,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } } - - claudeEvent := e.buildClaudeStreamEvent(pendingText, contentBlockIndex) + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(pendingText, contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2386,18 +1591,18 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Extract stop_reason from various event formats (streaming) // Kiro/Amazon Q API may include stop_reason in different locations - if sr := getString(event, "stop_reason"); sr != "" { + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { upstreamStopReason = sr log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) } - if sr := getString(event, "stopReason"); sr != "" { + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { upstreamStopReason = sr log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason) } // Send message_start on first event if !messageStartSent { - msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) + msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, 
&translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2415,11 +1620,11 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out case "messageStopEvent", "message_stop": // Handle message stop events which may contain stop_reason - if sr := getString(event, "stop_reason"); sr != "" { + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { upstreamStopReason = sr log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) } - if sr := getString(event, "stopReason"); sr != "" { + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { upstreamStopReason = sr log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) } @@ -2427,17 +1632,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out case "assistantResponseEvent": var contentDelta string var toolUses []map[string]interface{} - + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { if c, ok := assistantResp["content"].(string); ok { contentDelta = c } // Extract stop_reason from assistantResponseEvent - if sr := getString(assistantResp, "stop_reason"); sr != "" { + if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { upstreamStopReason = sr log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason) } - if sr := getString(assistantResp, "stopReason"); sr != "" { + if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { upstreamStopReason = sr log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) } @@ -2473,7 +1678,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out outputLen += len(contentDelta) // Accumulate content for streaming token calculation accumulatedContent.WriteString(contentDelta) - + // Real-time usage estimation: Check if we should send a usage 
update // This helps clients track context usage during long thinking sessions shouldSendUsageUpdate := false @@ -2482,7 +1687,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { shouldSendUsageUpdate = true } - + if shouldSendUsageUpdate { // Calculate current output tokens using tiktoken var currentOutputTokens int64 @@ -2498,24 +1703,24 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out currentOutputTokens = 1 } } - + // Only send update if token count has changed significantly (at least 10 tokens) if currentOutputTokens > lastReportedOutputTokens+10 { // Send ping event with usage information // This is a non-blocking update that clients can optionally process - pingEvent := e.buildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) + pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} } } - + lastReportedOutputTokens = currentOutputTokens log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) } - + lastUsageUpdateLen = accumulatedContent.Len() lastUsageUpdateTime = time.Now() } @@ -2526,20 +1731,20 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // If we have pending start tag chars from previous chunk, prepend them if pendingStartTagChars > 0 { - remaining = thinkingStartTag[:pendingStartTagChars] + remaining + remaining = kirocommon.ThinkingStartTag[:pendingStartTagChars] + remaining 
pendingStartTagChars = 0 } - + // If we have pending end tag chars from previous chunk, prepend them if pendingEndTagChars > 0 { - remaining = thinkingEndTag[:pendingEndTagChars] + remaining + remaining = kirocommon.ThinkingEndTag[:pendingEndTagChars] + remaining pendingEndTagChars = 0 } for len(remaining) > 0 { if inThinkBlock { // Inside thinking block - look for end tag - endIdx := strings.Index(remaining, thinkingEndTag) + endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag) if endIdx >= 0 { // Found end tag - emit any content before end tag, then close block thinkContent := remaining[:endIdx] @@ -2550,7 +1755,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out contentBlockIndex++ thinkingBlockIndex = contentBlockIndex isThinkingBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2558,9 +1763,9 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } } - + // Send thinking delta immediately - thinkingEvent := e.buildClaudeThinkingDeltaEvent(thinkContent, thinkingBlockIndex) + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkContent, thinkingBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2574,7 +1779,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close thinking block if isThinkingBlockOpen { - blockStop := e.buildClaudeContentBlockStopEvent(thinkingBlockIndex) + blockStop := 
kiroclaude.BuildClaudeContentBlockStopEvent(thinkingBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2585,13 +1790,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } inThinkBlock = false - remaining = remaining[endIdx+len(thinkingEndTag):] + remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):] log.Debugf("kiro: exited thinking block") } else { // No end tag found - TRUE STREAMING: emit content immediately // Only save potential partial tag length for next iteration - pendingEnd := pendingTagSuffix(remaining, thinkingEndTag) - + pendingEnd := kiroclaude.PendingTagSuffix(remaining, kirocommon.ThinkingEndTag) + // Calculate content to emit immediately (excluding potential partial tag) var contentToEmit string if pendingEnd > 0 { @@ -2601,7 +1806,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } else { contentToEmit = remaining } - + // TRUE STREAMING: Emit thinking content immediately if contentToEmit != "" { // Start thinking block if not open @@ -2609,39 +1814,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out contentBlockIndex++ thinkingBlockIndex = contentBlockIndex isThinkingBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - - // Send thinking delta immediately - TRUE STREAMING! 
- thinkingEvent := e.buildClaudeThinkingDeltaEvent(contentToEmit, thinkingBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - - remaining = "" - } - } else { - // Outside thinking block - look for start tag - startIdx := strings.Index(remaining, thinkingStartTag) - if startIdx >= 0 { - // Found start tag - emit text before it and switch to thinking mode - textBefore := remaining[:startIdx] - if textBefore != "" { - // Start text content block if needed - if !isTextBlockOpen { - contentBlockIndex++ - isTextBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2650,7 +1823,39 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - claudeEvent := e.buildClaudeStreamEvent(textBefore, contentBlockIndex) + // Send thinking delta immediately - TRUE STREAMING! 
+ thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(contentToEmit, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + remaining = "" + } + } else { + // Outside thinking block - look for start tag + startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + if startIdx >= 0 { + // Found start tag - emit text before it and switch to thinking mode + textBefore := remaining[:startIdx] + if textBefore != "" { + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2661,7 +1866,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close text block before starting thinking block if isTextBlockOpen { - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if 
chunk != "" { @@ -2672,11 +1877,11 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } inThinkBlock = true - remaining = remaining[startIdx+len(thinkingStartTag):] + remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):] log.Debugf("kiro: entered thinking block") } else { // No start tag found - check for partial start tag at buffer end - pendingStart := pendingTagSuffix(remaining, thinkingStartTag) + pendingStart := kiroclaude.PendingTagSuffix(remaining, kirocommon.ThinkingStartTag) if pendingStart > 0 { // Emit text except potential partial tag textToEmit := remaining[:len(remaining)-pendingStart] @@ -2685,7 +1890,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if !isTextBlockOpen { contentBlockIndex++ isTextBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2694,7 +1899,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - claudeEvent := e.buildClaudeStreamEvent(textToEmit, contentBlockIndex) + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2711,7 +1916,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if !isTextBlockOpen { contentBlockIndex++ isTextBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + blockStart := 
kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2720,7 +1925,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - claudeEvent := e.buildClaudeStreamEvent(remaining, contentBlockIndex) + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2734,22 +1939,22 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } } - + // Handle tool uses in response (with deduplication) for _, tu := range toolUses { - toolUseID := getString(tu, "toolUseId") - + toolUseID := kirocommon.GetString(tu, "toolUseId") + // Check for duplicate if processedIDs[toolUseID] { log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) continue } processedIDs[toolUseID] = true - + hasToolUses = true // Close text block if open before starting tool_use block if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2758,19 +1963,19 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } isTextBlockOpen = false } - + // Emit tool_use content block contentBlockIndex++ - toolName := getString(tu, "name") - - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, 
"tool_use", toolUseID, toolName) + toolName := kirocommon.GetString(tu, "name") + + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} } } - + // Send input_json_delta with the tool input if input, ok := tu["input"].(map[string]interface{}); ok { inputJSON, err := json.Marshal(input) @@ -2778,7 +1983,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out log.Debugf("kiro: failed to marshal tool input: %v", err) // Don't continue - still need to close the block } else { - inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2787,9 +1992,9 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } } - + // Close tool_use block (always close even if input marshal failed) - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2800,16 +2005,16 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out case "toolUseEvent": // Handle dedicated tool use events with input buffering - completedToolUses, newState := e.processToolUseEvent(event, 
currentToolUse, processedIDs) + completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) currentToolUse = newState - + // Emit completed tool uses for _, tu := range completedToolUses { hasToolUses = true - + // Close text block if open if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2818,23 +2023,23 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } isTextBlockOpen = false } - + contentBlockIndex++ - - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) + + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} } } - + if tu.Input != nil { inputJSON, err := json.Marshal(tu.Input) if err != nil { log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) } else { - inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2843,8 +2048,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } } - - 
blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2875,7 +2080,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Close content block if open if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2935,7 +2140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } // Send message_delta event - msgDelta := e.buildClaudeMessageDeltaEvent(stopReason, totalUsage) + msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2944,7 +2149,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } // Send message_stop event separately - msgStop := e.buildClaudeMessageStopOnlyEvent() + msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent() sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { @@ -2954,180 +2159,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // reporter.publish is called via defer } - -// Claude SSE event builders -// All 
builders return complete SSE format with "event:" line for Claude client compatibility. -func (e *KiroExecutor) buildClaudeMessageStartEvent(model string, inputTokens int64) []byte { - event := map[string]interface{}{ - "type": "message_start", - "message": map[string]interface{}{ - "id": "msg_" + uuid.New().String()[:24], - "type": "message", - "role": "assistant", - "content": []interface{}{}, - "model": model, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0}, - }, - } - result, _ := json.Marshal(event) - return []byte("event: message_start\ndata: " + string(result)) -} - -func (e *KiroExecutor) buildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { - var contentBlock map[string]interface{} - switch blockType { - case "tool_use": - contentBlock = map[string]interface{}{ - "type": "tool_use", - "id": toolUseID, - "name": toolName, - "input": map[string]interface{}{}, - } - case "thinking": - contentBlock = map[string]interface{}{ - "type": "thinking", - "thinking": "", - } - default: - contentBlock = map[string]interface{}{ - "type": "text", - "text": "", - } - } - - event := map[string]interface{}{ - "type": "content_block_start", - "index": index, - "content_block": contentBlock, - } - result, _ := json.Marshal(event) - return []byte("event: content_block_start\ndata: " + string(result)) -} - -func (e *KiroExecutor) buildClaudeStreamEvent(contentDelta string, index int) []byte { - event := map[string]interface{}{ - "type": "content_block_delta", - "index": index, - "delta": map[string]interface{}{ - "type": "text_delta", - "text": contentDelta, - }, - } - result, _ := json.Marshal(event) - return []byte("event: content_block_delta\ndata: " + string(result)) -} - -// buildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming -func (e *KiroExecutor) buildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte { 
- event := map[string]interface{}{ - "type": "content_block_delta", - "index": index, - "delta": map[string]interface{}{ - "type": "input_json_delta", - "partial_json": partialJSON, - }, - } - result, _ := json.Marshal(event) - return []byte("event: content_block_delta\ndata: " + string(result)) -} - -func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { - event := map[string]interface{}{ - "type": "content_block_stop", - "index": index, - } - result, _ := json.Marshal(event) - return []byte("event: content_block_stop\ndata: " + string(result)) -} - -// buildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage. -func (e *KiroExecutor) buildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte { - deltaEvent := map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{ - "stop_reason": stopReason, - "stop_sequence": nil, - }, - "usage": map[string]interface{}{ - "input_tokens": usageInfo.InputTokens, - "output_tokens": usageInfo.OutputTokens, - }, - } - deltaResult, _ := json.Marshal(deltaEvent) - return []byte("event: message_delta\ndata: " + string(deltaResult)) -} - -// buildClaudeMessageStopOnlyEvent creates only the message_stop event. -func (e *KiroExecutor) buildClaudeMessageStopOnlyEvent() []byte { - stopEvent := map[string]interface{}{ - "type": "message_stop", - } - stopResult, _ := json.Marshal(stopEvent) - return []byte("event: message_stop\ndata: " + string(stopResult)) -} - -// buildClaudeFinalEvent constructs the final Claude-style event. -func (e *KiroExecutor) buildClaudeFinalEvent() []byte { - event := map[string]interface{}{ - "type": "message_stop", - } - result, _ := json.Marshal(event) - return []byte("event: message_stop\ndata: " + string(result)) -} - -// buildClaudePingEventWithUsage creates a ping event with embedded usage information. -// This is used for real-time usage estimation during streaming. 
-// The usage field is a non-standard extension that clients can optionally process. -// Clients that don't recognize the usage field will simply ignore it. -func (e *KiroExecutor) buildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte { - event := map[string]interface{}{ - "type": "ping", - "usage": map[string]interface{}{ - "input_tokens": inputTokens, - "output_tokens": outputTokens, - "total_tokens": inputTokens + outputTokens, - "estimated": true, // Flag to indicate this is an estimate, not final - }, - } - result, _ := json.Marshal(event) - return []byte("event: ping\ndata: " + string(result)) -} - -// buildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility. -// This is used when streaming thinking content wrapped in tags. -func (e *KiroExecutor) buildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte { - event := map[string]interface{}{ - "type": "content_block_delta", - "index": index, - "delta": map[string]interface{}{ - "type": "thinking_delta", - "thinking": thinkingDelta, - }, - } - result, _ := json.Marshal(event) - return []byte("event: content_block_delta\ndata: " + string(result)) -} - -// pendingTagSuffix detects if the buffer ends with a partial prefix of the given tag. -// Returns the length of the partial match (0 if no match). -// Based on amq2api implementation for handling cross-chunk tag boundaries. -func pendingTagSuffix(buffer, tag string) int { - if buffer == "" || tag == "" { - return 0 - } - maxLen := len(buffer) - if maxLen > len(tag)-1 { - maxLen = len(tag) - 1 - } - for length := maxLen; length > 0; length-- { - if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] { - return length - } - } - return 0 -} +// NOTE: Claude SSE event builders moved to internal/translator/kiro/claude/kiro_claude_stream.go +// The executor now uses kiroclaude.BuildClaude*Event() functions instead // CountTokens is not supported for Kiro provider. 
// Kiro/Amazon Q backend doesn't expose a token counting API. @@ -3281,209 +2314,6 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c return updated, nil } -// streamEventStream converts AWS Event Stream to SSE (legacy method for gin.Context). -// Note: For full tool calling support, use streamToChannel instead. -func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c *gin.Context, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) error { - reader := bufio.NewReader(body) - var totalUsage usage.Detail - - // Translator param for maintaining tool call state across streaming events - var translatorParam any - - // Pre-calculate input tokens from request if possible - if enc, err := getTokenizer(model); err == nil { - // Try OpenAI format first, then fall back to raw byte count estimation - if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { - totalUsage.InputTokens = inp - } else { - // Fallback: estimate from raw request size (roughly 4 chars per token) - totalUsage.InputTokens = int64(len(originalReq) / 4) - if totalUsage.InputTokens == 0 && len(originalReq) > 0 { - totalUsage.InputTokens = 1 - } - } - log.Debugf("kiro: streamEventStream pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) - } - - contentBlockIndex := -1 - messageStartSent := false - isBlockOpen := false - var outputLen int - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - prelude := make([]byte, 8) - _, err := io.ReadFull(reader, prelude) - if err == io.EOF { - break - } - if err != nil { - return fmt.Errorf("failed to read prelude: %w", err) - } - - totalLen := binary.BigEndian.Uint32(prelude[0:4]) - if totalLen < 8 { - return fmt.Errorf("invalid message length: %d", totalLen) - } - if totalLen > kiroMaxMessageSize { - return fmt.Errorf("message too large: %d bytes", totalLen) - } 
- headersLen := binary.BigEndian.Uint32(prelude[4:8]) - - remaining := make([]byte, totalLen-8) - _, err = io.ReadFull(reader, remaining) - if err != nil { - return fmt.Errorf("failed to read message: %w", err) - } - - // Validate headersLen to prevent slice out of bounds - if headersLen+4 > uint32(len(remaining)) { - log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) - continue - } - - eventType := e.extractEventType(remaining[:headersLen+4]) - - payloadStart := 4 + headersLen - payloadEnd := uint32(len(remaining)) - 4 - if payloadStart >= payloadEnd { - continue - } - - payload := remaining[payloadStart:payloadEnd] - appendAPIResponseChunk(ctx, e.cfg, payload) - - var event map[string]interface{} - if err := json.Unmarshal(payload, &event); err != nil { - log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) - continue - } - - if !messageStartSent { - msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - c.Writer.Write([]byte(chunk + "\n\n")) - } - } - c.Writer.Flush() - messageStartSent = true - } - - switch eventType { - case "assistantResponseEvent": - var contentDelta string - if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { - if ct, ok := assistantResp["content"].(string); ok { - contentDelta = ct - } - } - if contentDelta == "" { - if ct, ok := event["content"].(string); ok { - contentDelta = ct - } - } - - if contentDelta != "" { - outputLen += len(contentDelta) - // Start text content block if needed - if !isBlockOpen { - contentBlockIndex++ - isBlockOpen = true - blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") - sseData := sdktranslator.TranslateStream(ctx, 
sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - c.Writer.Write([]byte(chunk + "\n\n")) - } - } - c.Writer.Flush() - } - - claudeEvent := e.buildClaudeStreamEvent(contentDelta, contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - c.Writer.Write([]byte(chunk + "\n\n")) - } - } - c.Writer.Flush() - } - - // Note: For full toolUseEvent support, use streamToChannel - - case "supplementaryWebLinksEvent": - if inputTokens, ok := event["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - if outputTokens, ok := event["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - } - - if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { - if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - } - } - - // Close content block if open - if isBlockOpen && contentBlockIndex >= 0 { - blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - c.Writer.Write([]byte(chunk + "\n\n")) - } - } - c.Writer.Flush() - } - - // Fallback for output tokens if not received from upstream - if totalUsage.OutputTokens == 0 && outputLen > 0 { - totalUsage.OutputTokens = int64(outputLen / 4) - if totalUsage.OutputTokens == 0 { - totalUsage.OutputTokens = 1 - } - } - totalUsage.TotalTokens = totalUsage.InputTokens + 
totalUsage.OutputTokens - - // Send message_delta event - msgDelta := e.buildClaudeMessageDeltaEvent("end_turn", totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - c.Writer.Write([]byte(chunk + "\n\n")) - } - } - c.Writer.Flush() - - // Send message_stop event separately - msgStop := e.buildClaudeMessageStopOnlyEvent() - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - c.Writer.Write([]byte(chunk + "\n\n")) - } - } - - c.Writer.Write([]byte("data: [DONE]\n\n")) - c.Writer.Flush() - - reporter.publish(ctx, totalUsage) - return nil -} - // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool { @@ -3542,666 +2372,6 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } -// ============================================================================ -// Message Merging Support - Merge adjacent messages with the same role -// Based on AIClient-2-API implementation -// ============================================================================ - -// mergeAdjacentMessages merges adjacent messages with the same role. -// This reduces API call complexity and improves compatibility. -// Based on AIClient-2-API implementation. 
-func mergeAdjacentMessages(messages []gjson.Result) []gjson.Result { - if len(messages) <= 1 { - return messages - } - - var merged []gjson.Result - for _, msg := range messages { - if len(merged) == 0 { - merged = append(merged, msg) - continue - } - - lastMsg := merged[len(merged)-1] - currentRole := msg.Get("role").String() - lastRole := lastMsg.Get("role").String() - - if currentRole == lastRole { - // Merge content from current message into last message - mergedContent := mergeMessageContent(lastMsg, msg) - // Create a new merged message JSON - mergedMsg := createMergedMessage(lastRole, mergedContent) - merged[len(merged)-1] = gjson.Parse(mergedMsg) - } else { - merged = append(merged, msg) - } - } - - return merged -} - -// mergeMessageContent merges the content of two messages with the same role. -// Handles both string content and array content (with text, tool_use, tool_result blocks). -func mergeMessageContent(msg1, msg2 gjson.Result) string { - content1 := msg1.Get("content") - content2 := msg2.Get("content") - - // Extract content blocks from both messages - var blocks1, blocks2 []map[string]interface{} - - if content1.IsArray() { - for _, block := range content1.Array() { - blocks1 = append(blocks1, blockToMap(block)) - } - } else if content1.Type == gjson.String { - blocks1 = append(blocks1, map[string]interface{}{ - "type": "text", - "text": content1.String(), - }) - } - - if content2.IsArray() { - for _, block := range content2.Array() { - blocks2 = append(blocks2, blockToMap(block)) - } - } else if content2.Type == gjson.String { - blocks2 = append(blocks2, map[string]interface{}{ - "type": "text", - "text": content2.String(), - }) - } - - // Merge text blocks if both end/start with text - if len(blocks1) > 0 && len(blocks2) > 0 { - if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" { - // Merge the last text block of msg1 with the first text block of msg2 - text1 := blocks1[len(blocks1)-1]["text"].(string) - text2 := 
blocks2[0]["text"].(string) - blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2 - blocks2 = blocks2[1:] // Remove the merged block from blocks2 - } - } - - // Combine all blocks - allBlocks := append(blocks1, blocks2...) - - // Convert to JSON - result, _ := json.Marshal(allBlocks) - return string(result) -} - -// blockToMap converts a gjson.Result block to a map[string]interface{} -func blockToMap(block gjson.Result) map[string]interface{} { - result := make(map[string]interface{}) - block.ForEach(func(key, value gjson.Result) bool { - if value.IsObject() { - result[key.String()] = blockToMap(value) - } else if value.IsArray() { - var arr []interface{} - for _, item := range value.Array() { - if item.IsObject() { - arr = append(arr, blockToMap(item)) - } else { - arr = append(arr, item.Value()) - } - } - result[key.String()] = arr - } else { - result[key.String()] = value.Value() - } - return true - }) - return result -} - -// createMergedMessage creates a JSON string for a merged message -func createMergedMessage(role string, content string) string { - msg := map[string]interface{}{ - "role": role, - "content": json.RawMessage(content), - } - result, _ := json.Marshal(msg) - return string(result) -} - -// ============================================================================ -// Tool Calling Support - Embedded tool call parsing and input buffering -// Based on amq2api and AIClient-2-API implementations -// ============================================================================ - -// toolUseState tracks the state of an in-progress tool use during streaming. 
-type toolUseState struct { - toolUseID string - name string - inputBuffer strings.Builder - isComplete bool -} - -// Pre-compiled regex patterns for performance (avoid recompilation on each call) -var ( - // embeddedToolCallPattern matches [Called tool_name with args: {...}] format - // This pattern is used by Kiro when it embeds tool calls in text content - embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+(\w+)\s+with\s+args:\s*`) - // whitespaceCollapsePattern collapses multiple whitespace characters into single space - whitespaceCollapsePattern = regexp.MustCompile(`\s+`) - // trailingCommaPattern matches trailing commas before closing braces/brackets - trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) -) - -// parseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. -// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent. -// Returns the cleaned text (with tool calls removed) and extracted tool uses. -func (e *KiroExecutor) parseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []kiroToolUse) { - if !strings.Contains(text, "[Called") { - return text, nil - } - - var toolUses []kiroToolUse - cleanText := text - - // Find all [Called markers - matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1) - if len(matches) == 0 { - return text, nil - } - - // Process matches in reverse order to maintain correct indices - for i := len(matches) - 1; i >= 0; i-- { - matchStart := matches[i][0] - toolNameStart := matches[i][2] - toolNameEnd := matches[i][3] - - if toolNameStart < 0 || toolNameEnd < 0 { - continue - } - - toolName := text[toolNameStart:toolNameEnd] - - // Find the JSON object start (after "with args:") - jsonStart := matches[i][1] - if jsonStart >= len(text) { - continue - } - - // Skip whitespace to find the opening brace - for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') { - jsonStart++ - } - - if jsonStart >= 
len(text) || text[jsonStart] != '{' { - continue - } - - // Find matching closing bracket - jsonEnd := findMatchingBracket(text, jsonStart) - if jsonEnd < 0 { - continue - } - - // Extract JSON and find the closing bracket of [Called ...] - jsonStr := text[jsonStart : jsonEnd+1] - - // Find the closing ] after the JSON - closingBracket := jsonEnd + 1 - for closingBracket < len(text) && text[closingBracket] != ']' { - closingBracket++ - } - if closingBracket >= len(text) { - continue - } - - // Extract and repair the full tool call text - fullMatch := text[matchStart : closingBracket+1] - - // Repair and parse JSON - repairedJSON := repairJSON(jsonStr) - var inputMap map[string]interface{} - if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil { - log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr) - continue - } - - // Generate unique tool ID - toolUseID := "toolu_" + uuid.New().String()[:12] - - // Check for duplicates using name+input as key - dedupeKey := toolName + ":" + repairedJSON - if processedIDs != nil { - if processedIDs[dedupeKey] { - log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName) - // Still remove from text even if duplicate - cleanText = strings.Replace(cleanText, fullMatch, "", 1) - continue - } - processedIDs[dedupeKey] = true - } - - toolUses = append(toolUses, kiroToolUse{ - ToolUseID: toolUseID, - Name: toolName, - Input: inputMap, - }) - - log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID) - - // Remove from clean text - cleanText = strings.Replace(cleanText, fullMatch, "", 1) - } - - // Clean up extra whitespace - cleanText = strings.TrimSpace(cleanText) - cleanText = whitespaceCollapsePattern.ReplaceAllString(cleanText, " ") - - return cleanText, toolUses -} - -// findMatchingBracket finds the index of the closing brace/bracket that matches -// the opening one at startPos. Handles nested objects and strings correctly. 
-func findMatchingBracket(text string, startPos int) int { - if startPos >= len(text) { - return -1 - } - - openChar := text[startPos] - var closeChar byte - switch openChar { - case '{': - closeChar = '}' - case '[': - closeChar = ']' - default: - return -1 - } - - depth := 1 - inString := false - escapeNext := false - - for i := startPos + 1; i < len(text); i++ { - char := text[i] - - if escapeNext { - escapeNext = false - continue - } - - if char == '\\' && inString { - escapeNext = true - continue - } - - if char == '"' { - inString = !inString - continue - } - - if !inString { - if char == openChar { - depth++ - } else if char == closeChar { - depth-- - if depth == 0 { - return i - } - } - } - } - - return -1 -} - -// repairJSON attempts to fix common JSON issues that may occur in tool call arguments. -// Based on AIClient-2-API's JSON repair implementation with a more conservative strategy. -// -// Conservative repair strategy: -// 1. First try to parse JSON directly - if valid, return as-is -// 2. Only attempt repair if parsing fails -// 3. After repair, validate the result - if still invalid, return original -// -// Handles incomplete JSON by balancing brackets and removing trailing incomplete content. -// Uses pre-compiled regex patterns for performance. 
-func repairJSON(jsonString string) string { - // Handle empty or invalid input - if jsonString == "" { - return "{}" - } - - str := strings.TrimSpace(jsonString) - if str == "" { - return "{}" - } - - // CONSERVATIVE STRATEGY: First try to parse directly - // If the JSON is already valid, return it unchanged - var testParse interface{} - if err := json.Unmarshal([]byte(str), &testParse); err == nil { - log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged") - return str - } - - log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair") - originalStr := str // Keep original for fallback - - // First, escape unescaped newlines/tabs within JSON string values - str = escapeNewlinesInStrings(str) - // Remove trailing commas before closing braces/brackets - str = trailingCommaPattern.ReplaceAllString(str, "$1") - - // Calculate bracket balance to detect incomplete JSON - braceCount := 0 // {} balance - bracketCount := 0 // [] balance - inString := false - escape := false - lastValidIndex := -1 - - for i := 0; i < len(str); i++ { - char := str[i] - - // Handle escape sequences - if escape { - escape = false - continue - } - - if char == '\\' { - escape = true - continue - } - - // Handle string boundaries - if char == '"' { - inString = !inString - continue - } - - // Skip characters inside strings (they don't affect bracket balance) - if inString { - continue - } - - // Track bracket balance - switch char { - case '{': - braceCount++ - case '}': - braceCount-- - case '[': - bracketCount++ - case ']': - bracketCount-- - } - - // Record last valid position (where brackets are balanced or positive) - if braceCount >= 0 && bracketCount >= 0 { - lastValidIndex = i - } - } - - // If brackets are unbalanced, try to repair - if braceCount > 0 || bracketCount > 0 { - // Truncate to last valid position if we have incomplete content - if lastValidIndex > 0 && lastValidIndex < len(str)-1 { - // Check if truncation would help (only truncate if there's 
trailing garbage) - truncated := str[:lastValidIndex+1] - // Recount brackets after truncation - braceCount = 0 - bracketCount = 0 - inString = false - escape = false - for i := 0; i < len(truncated); i++ { - char := truncated[i] - if escape { - escape = false - continue - } - if char == '\\' { - escape = true - continue - } - if char == '"' { - inString = !inString - continue - } - if inString { - continue - } - switch char { - case '{': - braceCount++ - case '}': - braceCount-- - case '[': - bracketCount++ - case ']': - bracketCount-- - } - } - str = truncated - } - - // Add missing closing brackets - for braceCount > 0 { - str += "}" - braceCount-- - } - for bracketCount > 0 { - str += "]" - bracketCount-- - } - } - - // CONSERVATIVE STRATEGY: Validate repaired JSON - // If repair didn't produce valid JSON, return original string - if err := json.Unmarshal([]byte(str), &testParse); err != nil { - log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original") - return originalStr - } - - log.Debugf("kiro: repairJSON - successfully repaired JSON") - return str -} - -// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters -// that appear inside JSON string values. This handles cases where streaming fragments -// contain unescaped control characters within string content. 
-func escapeNewlinesInStrings(raw string) string { - var result strings.Builder - result.Grow(len(raw) + 100) // Pre-allocate with some extra space - - inString := false - escaped := false - - for i := 0; i < len(raw); i++ { - c := raw[i] - - if escaped { - // Previous character was backslash, this is an escape sequence - result.WriteByte(c) - escaped = false - continue - } - - if c == '\\' && inString { - // Start of escape sequence - result.WriteByte(c) - escaped = true - continue - } - - if c == '"' { - // Toggle string state - inString = !inString - result.WriteByte(c) - continue - } - - if inString { - // Inside a string, escape control characters - switch c { - case '\n': - result.WriteString("\\n") - case '\r': - result.WriteString("\\r") - case '\t': - result.WriteString("\\t") - default: - result.WriteByte(c) - } - } else { - result.WriteByte(c) - } - } - - return result.String() -} - -// processToolUseEvent handles a toolUseEvent from the Kiro stream. -// It accumulates input fragments and emits tool_use blocks when complete. -// Returns events to emit and updated state. 
-func (e *KiroExecutor) processToolUseEvent(event map[string]interface{}, currentToolUse *toolUseState, processedIDs map[string]bool) ([]kiroToolUse, *toolUseState) { - var toolUses []kiroToolUse - - // Extract from nested toolUseEvent or direct format - tu := event - if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok { - tu = nested - } - - toolUseID := getString(tu, "toolUseId") - toolName := getString(tu, "name") - isStop := false - if stop, ok := tu["stop"].(bool); ok { - isStop = stop - } - - // Get input - can be string (fragment) or object (complete) - var inputFragment string - var inputMap map[string]interface{} - - if inputRaw, ok := tu["input"]; ok { - switch v := inputRaw.(type) { - case string: - inputFragment = v - case map[string]interface{}: - inputMap = v - } - } - - // New tool use starting - if toolUseID != "" && toolName != "" { - if currentToolUse != nil && currentToolUse.toolUseID != toolUseID { - // New tool use arrived while another is in progress (interleaved events) - // This is unusual - log warning and complete the previous one - log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous", - toolUseID, currentToolUse.toolUseID) - // Emit incomplete previous tool use - if !processedIDs[currentToolUse.toolUseID] { - incomplete := kiroToolUse{ - ToolUseID: currentToolUse.toolUseID, - Name: currentToolUse.name, - } - if currentToolUse.inputBuffer.Len() > 0 { - var input map[string]interface{} - if err := json.Unmarshal([]byte(currentToolUse.inputBuffer.String()), &input); err == nil { - incomplete.Input = input - } - } - toolUses = append(toolUses, incomplete) - processedIDs[currentToolUse.toolUseID] = true - } - currentToolUse = nil - } - - if currentToolUse == nil { - // Check for duplicate - if processedIDs != nil && processedIDs[toolUseID] { - log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID) - return nil, nil - } - - currentToolUse = &toolUseState{ - 
toolUseID: toolUseID, - name: toolName, - } - log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID) - } - } - - // Accumulate input fragments - if currentToolUse != nil && inputFragment != "" { - // Accumulate fragments directly - they form valid JSON when combined - // The fragments are already decoded from JSON, so we just concatenate them - currentToolUse.inputBuffer.WriteString(inputFragment) - log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.inputBuffer.Len()) - } - - // If complete input object provided directly - if currentToolUse != nil && inputMap != nil { - inputBytes, _ := json.Marshal(inputMap) - currentToolUse.inputBuffer.Reset() - currentToolUse.inputBuffer.Write(inputBytes) - } - - // Tool use complete - if isStop && currentToolUse != nil { - fullInput := currentToolUse.inputBuffer.String() - - // Repair and parse the accumulated JSON - repairedJSON := repairJSON(fullInput) - var finalInput map[string]interface{} - if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { - log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput) - // Use empty input as fallback - finalInput = make(map[string]interface{}) - } - - toolUse := kiroToolUse{ - ToolUseID: currentToolUse.toolUseID, - Name: currentToolUse.name, - Input: finalInput, - } - toolUses = append(toolUses, toolUse) - - // Mark as processed - if processedIDs != nil { - processedIDs[currentToolUse.toolUseID] = true - } - - log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) - return toolUses, nil // Reset state - } - - return toolUses, currentToolUse -} - -// deduplicateToolUses removes duplicate tool uses based on toolUseId and content (name+arguments). -// This prevents both ID-based duplicates and content-based duplicates (same tool call with different IDs). 
-func deduplicateToolUses(toolUses []kiroToolUse) []kiroToolUse { - seenIDs := make(map[string]bool) - seenContent := make(map[string]bool) // Content-based deduplication (name + arguments) - var unique []kiroToolUse - - for _, tu := range toolUses { - // Skip if we've already seen this ID - if seenIDs[tu.ToolUseID] { - log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name) - continue - } - - // Build content key for content-based deduplication - inputJSON, _ := json.Marshal(tu.Input) - contentKey := tu.Name + ":" + string(inputJSON) - - // Skip if we've already seen this content (same name + arguments) - if seenContent[contentKey] { - log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID) - continue - } - - seenIDs[tu.ToolUseID] = true - seenContent[contentKey] = true - unique = append(unique, tu) - } - - return unique -} +// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go +// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go +// The executor now uses kiroclaude.* and kirocommon.* functions instead diff --git a/internal/translator/init.go b/internal/translator/init.go index d19d9b34..0754db03 100644 --- a/internal/translator/init.go +++ b/internal/translator/init.go @@ -35,5 +35,5 @@ import ( _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" - _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai" ) diff --git a/internal/translator/kiro/claude/init.go b/internal/translator/kiro/claude/init.go index 9e3a2ba3..1685d195 100644 --- a/internal/translator/kiro/claude/init.go +++ b/internal/translator/kiro/claude/init.go @@ -1,3 +1,4 @@ +// Package claude provides translation 
between Kiro and Claude formats. package claude import ( @@ -12,8 +13,8 @@ func init() { Kiro, ConvertClaudeRequestToKiro, interfaces.TranslateResponse{ - Stream: ConvertKiroResponseToClaude, - NonStream: ConvertKiroResponseToClaudeNonStream, + Stream: ConvertKiroStreamToClaude, + NonStream: ConvertKiroNonStreamToClaude, }, ) } diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go index 554dbf21..752a00d9 100644 --- a/internal/translator/kiro/claude/kiro_claude.go +++ b/internal/translator/kiro/claude/kiro_claude.go @@ -1,27 +1,21 @@ // Package claude provides translation between Kiro and Claude formats. // Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix), -// translations are pass-through. +// translations are pass-through for streaming, but responses need proper formatting. package claude import ( - "bytes" "context" ) -// ConvertClaudeRequestToKiro converts Claude request to Kiro format. -// Since Kiro uses Claude format internally, this is mostly a pass-through. -func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { - return bytes.Clone(inputRawJSON) -} - -// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format. +// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format. // Kiro executor already generates complete SSE format with "event:" prefix, // so this is a simple pass-through. -func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { +func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { return []string{string(rawResponse)} } -// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format. 
// ConvertKiroNonStreamToClaude converts a Kiro non-streaming response to
// Claude format. The executor already produces a Claude-shaped body, so the
// response is forwarded unchanged.
func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
	return string(rawResponse)
}

// Kiro API request structs. Field declaration order matters: encoding/json
// emits keys in declaration order and the Kiro endpoint is sensitive to it.

// KiroPayload is the top-level request body sent to the Kiro API.
type KiroPayload struct {
	ConversationState KiroConversationState `json:"conversationState"`
	ProfileArn        string                `json:"profileArn,omitempty"`
	InferenceConfig   *KiroInferenceConfig  `json:"inferenceConfig,omitempty"`
}

// KiroInferenceConfig carries optional inference parameters.
// NOTE(review): omitempty drops an explicit temperature of 0 from the wire
// format — confirm the upstream API treats omission the same as 0.
type KiroInferenceConfig struct {
	MaxTokens   int     `json:"maxTokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
}

// KiroConversationState holds the conversation context for one request.
type KiroConversationState struct {
	ChatTriggerType string               `json:"chatTriggerType"` // Required: "MANUAL" — must serialize first
	ConversationID  string               `json:"conversationId"`
	CurrentMessage  KiroCurrentMessage   `json:"currentMessage"`
	History         []KiroHistoryMessage `json:"history,omitempty"`
}

// KiroCurrentMessage wraps the message that triggers this turn.
type KiroCurrentMessage struct {
	UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
}

// KiroHistoryMessage is one prior turn: exactly one of the two fields is set.
type KiroHistoryMessage struct {
	UserInputMessage         *KiroUserInputMessage         `json:"userInputMessage,omitempty"`
	AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
}

// KiroImage is an inline image attachment in Kiro API format.
type KiroImage struct {
	Format string          `json:"format"`
	Source KiroImageSource `json:"source"`
}

// KiroImageSource carries the raw image payload.
type KiroImageSource struct {
	Bytes string `json:"bytes"` // base64-encoded image data
}

// KiroUserInputMessage is a user turn, optionally with images and tool context.
type KiroUserInputMessage struct {
	Content                 string                       `json:"content"`
	ModelID                 string                       `json:"modelId"`
	Origin                  string                       `json:"origin"`
	Images                  []KiroImage                  `json:"images,omitempty"`
	UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
}

// KiroUserInputMessageContext carries tool declarations and tool results.
type KiroUserInputMessageContext struct {
	ToolResults []KiroToolResult  `json:"toolResults,omitempty"`
	Tools       []KiroToolWrapper `json:"tools,omitempty"`
}

// KiroToolResult is the outcome of one tool invocation.
type KiroToolResult struct {
	Content   []KiroTextContent `json:"content"`
	Status    string            `json:"status"`
	ToolUseID string            `json:"toolUseId"`
}

// KiroTextContent is a plain-text content entry.
type KiroTextContent struct {
	Text string `json:"text"`
}

// KiroToolWrapper wraps a single tool specification.
type KiroToolWrapper struct {
	ToolSpecification KiroToolSpecification `json:"toolSpecification"`
}

// KiroToolSpecification declares a tool and its input schema.
type KiroToolSpecification struct {
	Name        string          `json:"name"`
	Description string          `json:"description"`
	InputSchema KiroInputSchema `json:"inputSchema"`
}

// KiroInputSchema wraps the JSON schema describing a tool's input.
type KiroInputSchema struct {
	JSON interface{} `json:"json"`
}

// KiroAssistantResponseMessage is an assistant turn, optionally with tool uses.
type KiroAssistantResponseMessage struct {
	Content  string        `json:"content"`
	ToolUses []KiroToolUse `json:"toolUses,omitempty"`
}

// KiroToolUse records one tool invocation requested by the assistant.
type KiroToolUse struct {
	ToolUseID string                 `json:"toolUseId"`
	Name      string                 `json:"name"`
	Input     map[string]interface{} `json:"input"`
}

// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format.
// The Claude body is forwarded as-is: BuildKiroPayload performs the real
// conversion later, when the executor assembles the HTTP request.
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
	return inputRawJSON
}
+// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. +func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { + // Extract max_tokens for potential use in inferenceConfig + var maxTokens int64 + if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() { + maxTokens = mt.Int() + } + + // Extract temperature if specified + var temperature float64 + var hasTemperature bool + if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() { + temperature = temp.Float() + hasTemperature = true + } + + // Normalize origin value for Kiro API compatibility + origin = normalizeOrigin(origin) + log.Debugf("kiro: normalized origin value: %s", origin) + + messages := gjson.GetBytes(claudeBody, "messages") + + // For chat-only mode, don't include tools + var tools gjson.Result + if !isChatOnly { + tools = gjson.GetBytes(claudeBody, "tools") + } + + // Extract system prompt + systemPrompt := extractSystemPrompt(claudeBody) + + // Check for thinking mode + thinkingEnabled, budgetTokens := checkThinkingMode(claudeBody) + + // Inject timestamp context + timestamp := time.Now().Format("2006-01-02 15:04:05 MST") + timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp) + if systemPrompt != "" { + systemPrompt = timestampContext + "\n\n" + systemPrompt + } else { + systemPrompt = timestampContext + } + log.Debugf("kiro: injected timestamp context: %s", timestamp) + + // Inject agentic optimization prompt for -agentic model variants + if isAgentic { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += kirocommon.KiroAgenticSystemPrompt + } + + // Inject thinking hint when thinking mode is enabled + if thinkingEnabled { + if systemPrompt != "" { + systemPrompt += "\n" + } + dynamicThinkingHint := fmt.Sprintf("interleaved%d", budgetTokens) + systemPrompt += dynamicThinkingHint + log.Debugf("kiro: injected dynamic thinking hint into system 
prompt, max_thinking_length: %d", budgetTokens) + } + + // Convert Claude tools to Kiro format + kiroTools := convertClaudeToolsToKiro(tools) + + // Process messages and build history + history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin) + + // Build content with system prompt + if currentUserMsg != nil { + currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults) + + // Deduplicate currentToolResults + currentToolResults = deduplicateToolResults(currentToolResults) + + // Build userInputMessageContext with tools and tool results + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + } + + // Build payload + var currentMessage KiroCurrentMessage + if currentUserMsg != nil { + currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg} + } else { + fallbackContent := "" + if systemPrompt != "" { + fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n" + } + currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{ + Content: fallbackContent, + ModelID: modelID, + Origin: origin, + }} + } + + // Build inferenceConfig if we have any inference parameters + var inferenceConfig *KiroInferenceConfig + if maxTokens > 0 || hasTemperature { + inferenceConfig = &KiroInferenceConfig{} + if maxTokens > 0 { + inferenceConfig.MaxTokens = int(maxTokens) + } + if hasTemperature { + inferenceConfig.Temperature = temperature + } + } + + payload := KiroPayload{ + ConversationState: KiroConversationState{ + ChatTriggerType: "MANUAL", + ConversationID: uuid.New().String(), + CurrentMessage: currentMessage, + History: history, + }, + ProfileArn: profileArn, + InferenceConfig: inferenceConfig, + } + + result, err := json.Marshal(payload) + if err != nil { + log.Debugf("kiro: failed to marshal payload: %v", 
err) + return nil + } + + return result +} + +// normalizeOrigin normalizes origin value for Kiro API compatibility +func normalizeOrigin(origin string) string { + switch origin { + case "KIRO_CLI": + return "CLI" + case "KIRO_AI_EDITOR": + return "AI_EDITOR" + case "AMAZON_Q": + return "CLI" + case "KIRO_IDE": + return "AI_EDITOR" + default: + return origin + } +} + +// extractSystemPrompt extracts system prompt from Claude request +func extractSystemPrompt(claudeBody []byte) string { + systemField := gjson.GetBytes(claudeBody, "system") + if systemField.IsArray() { + var sb strings.Builder + for _, block := range systemField.Array() { + if block.Get("type").String() == "text" { + sb.WriteString(block.Get("text").String()) + } else if block.Type == gjson.String { + sb.WriteString(block.String()) + } + } + return sb.String() + } + return systemField.String() +} + +// checkThinkingMode checks if thinking mode is enabled in the Claude request +func checkThinkingMode(claudeBody []byte) (bool, int64) { + thinkingEnabled := false + var budgetTokens int64 = 16000 + + thinkingField := gjson.GetBytes(claudeBody, "thinking") + if thinkingField.Exists() { + thinkingType := thinkingField.Get("type").String() + if thinkingType == "enabled" { + thinkingEnabled = true + if bt := thinkingField.Get("budget_tokens"); bt.Exists() { + budgetTokens = bt.Int() + if budgetTokens <= 0 { + thinkingEnabled = false + log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0") + } + } + if thinkingEnabled { + log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) + } + } + } + + return thinkingEnabled, budgetTokens +} + +// convertClaudeToolsToKiro converts Claude tools to Kiro format +func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { + var kiroTools []KiroToolWrapper + if !tools.IsArray() { + return kiroTools + } + + for _, tool := range tools.Array() { + name := tool.Get("name").String() + description := 
tool.Get("description").String() + inputSchema := tool.Get("input_schema").Value() + + // CRITICAL FIX: Kiro API requires non-empty description + if strings.TrimSpace(description) == "" { + description = fmt.Sprintf("Tool: %s", name) + log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description) + } + + // Truncate long descriptions + if len(description) > kirocommon.KiroMaxToolDescLen { + truncLen := kirocommon.KiroMaxToolDescLen - 30 + for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { + truncLen-- + } + description = description[:truncLen] + "... (description truncated)" + } + + kiroTools = append(kiroTools, KiroToolWrapper{ + ToolSpecification: KiroToolSpecification{ + Name: name, + Description: description, + InputSchema: KiroInputSchema{JSON: inputSchema}, + }, + }) + } + + return kiroTools +} + +// processMessages processes Claude messages and builds Kiro history +func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) { + var history []KiroHistoryMessage + var currentUserMsg *KiroUserInputMessage + var currentToolResults []KiroToolResult + + // Merge adjacent messages with the same role + messagesArray := kirocommon.MergeAdjacentMessages(messages.Array()) + for i, msg := range messagesArray { + role := msg.Get("role").String() + isLastMessage := i == len(messagesArray)-1 + + if role == "user" { + userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin) + if isLastMessage { + currentUserMsg = &userMsg + currentToolResults = toolResults + } else { + // CRITICAL: Kiro API requires content to be non-empty for history messages too + if strings.TrimSpace(userMsg.Content) == "" { + if len(toolResults) > 0 { + userMsg.Content = "Tool results provided." 
+ } else { + userMsg.Content = "Continue" + } + } + // For history messages, embed tool results in context + if len(toolResults) > 0 { + userMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + ToolResults: toolResults, + } + } + history = append(history, KiroHistoryMessage{ + UserInputMessage: &userMsg, + }) + } + } else if role == "assistant" { + assistantMsg := BuildAssistantMessageStruct(msg) + if isLastMessage { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + // Create a "Continue" user message as currentMessage + currentUserMsg = &KiroUserInputMessage{ + Content: "Continue", + ModelID: modelID, + Origin: origin, + } + } else { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } + } + } + + return history, currentUserMsg, currentToolResults +} + +// buildFinalContent builds the final content with system prompt +func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string { + var contentBuilder strings.Builder + + if systemPrompt != "" { + contentBuilder.WriteString("--- SYSTEM PROMPT ---\n") + contentBuilder.WriteString(systemPrompt) + contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n") + } + + contentBuilder.WriteString(content) + finalContent := contentBuilder.String() + + // CRITICAL: Kiro API requires content to be non-empty + if strings.TrimSpace(finalContent) == "" { + if len(toolResults) > 0 { + finalContent = "Tool results provided." 
+ } else { + finalContent = "Continue" + } + log.Debugf("kiro: content was empty, using default: %s", finalContent) + } + + return finalContent +} + +// deduplicateToolResults removes duplicate tool results +func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult { + if len(toolResults) == 0 { + return toolResults + } + + seenIDs := make(map[string]bool) + unique := make([]KiroToolResult, 0, len(toolResults)) + for _, tr := range toolResults { + if !seenIDs[tr.ToolUseID] { + seenIDs[tr.ToolUseID] = true + unique = append(unique, tr) + } else { + log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID) + } + } + return unique +} + +// BuildUserMessageStruct builds a user message and extracts tool results +func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolResults []KiroToolResult + var images []KiroImage + + // Track seen toolUseIds to deduplicate + seenToolUseIDs := make(map[string]bool) + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "image": + mediaType := part.Get("source.media_type").String() + data := part.Get("source.data").String() + + format := "" + if idx := strings.LastIndex(mediaType, "/"); idx != -1 { + format = mediaType[idx+1:] + } + + if format != "" && data != "" { + images = append(images, KiroImage{ + Format: format, + Source: KiroImageSource{ + Bytes: data, + }, + }) + } + case "tool_result": + toolUseID := part.Get("tool_use_id").String() + + // Skip duplicate toolUseIds + if seenToolUseIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID) + continue + } + seenToolUseIDs[toolUseID] = true + + isError := part.Get("is_error").Bool() + resultContent := 
part.Get("content") + + var textContents []KiroTextContent + if resultContent.IsArray() { + for _, item := range resultContent.Array() { + if item.Get("type").String() == "text" { + textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()}) + } else if item.Type == gjson.String { + textContents = append(textContents, KiroTextContent{Text: item.String()}) + } + } + } else if resultContent.Type == gjson.String { + textContents = append(textContents, KiroTextContent{Text: resultContent.String()}) + } + + if len(textContents) == 0 { + textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"}) + } + + status := "success" + if isError { + status = "error" + } + + toolResults = append(toolResults, KiroToolResult{ + ToolUseID: toolUseID, + Content: textContents, + Status: status, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + userMsg := KiroUserInputMessage{ + Content: contentBuilder.String(), + ModelID: modelID, + Origin: origin, + } + + if len(images) > 0 { + userMsg.Images = images + } + + return userMsg, toolResults +} + +// BuildAssistantMessageStruct builds an assistant message with tool uses +func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolUses []KiroToolUse + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "tool_use": + toolUseID := part.Get("id").String() + toolName := part.Get("name").String() + toolInput := part.Get("input") + + var inputMap map[string]interface{} + if toolInput.IsObject() { + inputMap = make(map[string]interface{}) + toolInput.ForEach(func(key, value gjson.Result) bool { + inputMap[key.String()] = value.Value() + return true + }) + } + + toolUses = append(toolUses, KiroToolUse{ + 
ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + return KiroAssistantResponseMessage{ + Content: contentBuilder.String(), + ToolUses: toolUses, + } +} \ No newline at end of file diff --git a/internal/translator/kiro/claude/kiro_claude_response.go b/internal/translator/kiro/claude/kiro_claude_response.go new file mode 100644 index 00000000..49ebf79e --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_response.go @@ -0,0 +1,184 @@ +// Package claude provides response translation functionality for Kiro API to Claude format. +// This package handles the conversion of Kiro API responses into Claude-compatible format, +// including support for thinking blocks and tool use. +package claude + +import ( + "encoding/json" + "strings" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + log "github.com/sirupsen/logrus" + + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" +) + +// Local references to kirocommon constants for thinking block parsing +var ( + thinkingStartTag = kirocommon.ThinkingStartTag + thinkingEndTag = kirocommon.ThinkingEndTag +) + +// BuildClaudeResponse constructs a Claude-compatible response. +// Supports tool_use blocks when tools are present in the response. +// Supports thinking blocks - parses tags and converts to Claude thinking content blocks. +// stopReason is passed from upstream; fallback logic applied if empty. +func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { + var contentBlocks []map[string]interface{} + + // Extract thinking blocks and text from content + if content != "" { + blocks := ExtractThinkingFromContent(content) + contentBlocks = append(contentBlocks, blocks...) 
+ + // Log if thinking blocks were extracted + for _, block := range blocks { + if block["type"] == "thinking" { + thinkingContent := block["thinking"].(string) + log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent)) + } + } + } + + // Add tool_use blocks + for _, toolUse := range toolUses { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "tool_use", + "id": toolUse.ToolUseID, + "name": toolUse.Name, + "input": toolUse.Input, + }) + } + + // Ensure at least one content block (Claude API requires non-empty content) + if len(contentBlocks) == 0 { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + // Use upstream stopReason; apply fallback logic if not provided + if stopReason == "" { + stopReason = "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason) + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)") + } + + response := map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": "message", + "role": "assistant", + "model": model, + "content": contentBlocks, + "stop_reason": stopReason, + "usage": map[string]interface{}{ + "input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + result, _ := json.Marshal(response) + return result +} + +// ExtractThinkingFromContent parses content to extract thinking blocks and text. +// Returns a list of content blocks in the order they appear in the content. +// Handles interleaved thinking and text blocks correctly. 
+func ExtractThinkingFromContent(content string) []map[string]interface{} { + var blocks []map[string]interface{} + + if content == "" { + return blocks + } + + // Check if content contains thinking tags at all + if !strings.Contains(content, thinkingStartTag) { + // No thinking tags, return as plain text + return []map[string]interface{}{ + { + "type": "text", + "text": content, + }, + } + } + + log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content)) + + remaining := content + + for len(remaining) > 0 { + // Look for tag + startIdx := strings.Index(remaining, thinkingStartTag) + + if startIdx == -1 { + // No more thinking tags, add remaining as text + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": remaining, + }) + } + break + } + + // Add text before thinking tag (if any meaningful content) + if startIdx > 0 { + textBefore := remaining[:startIdx] + if strings.TrimSpace(textBefore) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": textBefore, + }) + } + } + + // Move past the opening tag + remaining = remaining[startIdx+len(thinkingStartTag):] + + // Find closing tag + endIdx := strings.Index(remaining, thinkingEndTag) + + if endIdx == -1 { + // No closing tag found, treat rest as thinking content (incomplete response) + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": remaining, + }) + log.Warnf("kiro: extractThinkingFromContent - missing closing tag") + } + break + } + + // Extract thinking content between tags + thinkContent := remaining[:endIdx] + if strings.TrimSpace(thinkContent) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": thinkContent, + }) + log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent)) + } + + // Move past the closing tag 
+ remaining = remaining[endIdx+len(thinkingEndTag):] + } + + // If no blocks were created (all whitespace), return empty text block + if len(blocks) == 0 { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + return blocks +} \ No newline at end of file diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go new file mode 100644 index 00000000..6ea6e4cd --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_stream.go @@ -0,0 +1,176 @@ +// Package claude provides streaming SSE event building for Claude format. +// This package handles the construction of Claude-compatible Server-Sent Events (SSE) +// for streaming responses from Kiro API. +package claude + +import ( + "encoding/json" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" +) + +// BuildClaudeMessageStartEvent creates the message_start SSE event +func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte { + event := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": "message", + "role": "assistant", + "content": []interface{}{}, + "model": model, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0}, + }, + } + result, _ := json.Marshal(event) + return []byte("event: message_start\ndata: " + string(result)) +} + +// BuildClaudeContentBlockStartEvent creates a content_block_start SSE event +func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { + var contentBlock map[string]interface{} + switch blockType { + case "tool_use": + contentBlock = map[string]interface{}{ + "type": "tool_use", + "id": toolUseID, + "name": toolName, + "input": map[string]interface{}{}, + } + case "thinking": + contentBlock = map[string]interface{}{ + 
"type": "thinking", + "thinking": "", + } + default: + contentBlock = map[string]interface{}{ + "type": "text", + "text": "", + } + } + + event := map[string]interface{}{ + "type": "content_block_start", + "index": index, + "content_block": contentBlock, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_start\ndata: " + string(result)) +} + +// BuildClaudeStreamEvent creates a text_delta content_block_delta SSE event +func BuildClaudeStreamEvent(contentDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": contentDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// BuildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming +func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": partialJSON, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// BuildClaudeContentBlockStopEvent creates a content_block_stop SSE event +func BuildClaudeContentBlockStopEvent(index int) []byte { + event := map[string]interface{}{ + "type": "content_block_stop", + "index": index, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_stop\ndata: " + string(result)) +} + +// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage +func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte { + deltaEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + 
"input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + deltaResult, _ := json.Marshal(deltaEvent) + return []byte("event: message_delta\ndata: " + string(deltaResult)) +} + +// BuildClaudeMessageStopOnlyEvent creates only the message_stop event +func BuildClaudeMessageStopOnlyEvent() []byte { + stopEvent := map[string]interface{}{ + "type": "message_stop", + } + stopResult, _ := json.Marshal(stopEvent) + return []byte("event: message_stop\ndata: " + string(stopResult)) +} + +// BuildClaudePingEventWithUsage creates a ping event with embedded usage information. +// This is used for real-time usage estimation during streaming. +func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte { + event := map[string]interface{}{ + "type": "ping", + "usage": map[string]interface{}{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + "estimated": true, + }, + } + result, _ := json.Marshal(event) + return []byte("event: ping\ndata: " + string(result)) +} + +// BuildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility. +// This is used when streaming thinking content wrapped in tags. +func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "thinking_delta", + "thinking": thinkingDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// PendingTagSuffix detects if the buffer ends with a partial prefix of the given tag. +// Returns the length of the partial match (0 if no match). +// Based on amq2api implementation for handling cross-chunk tag boundaries. 
+func PendingTagSuffix(buffer, tag string) int { + if buffer == "" || tag == "" { + return 0 + } + maxLen := len(buffer) + if maxLen > len(tag)-1 { + maxLen = len(tag) - 1 + } + for length := maxLen; length > 0; length-- { + if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] { + return length + } + } + return 0 +} \ No newline at end of file diff --git a/internal/translator/kiro/claude/kiro_claude_tools.go b/internal/translator/kiro/claude/kiro_claude_tools.go new file mode 100644 index 00000000..93ede875 --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_tools.go @@ -0,0 +1,522 @@ +// Package claude provides tool calling support for Kiro to Claude translation. +// This package handles parsing embedded tool calls, JSON repair, and deduplication. +package claude + +import ( + "encoding/json" + "regexp" + "strings" + + "github.com/google/uuid" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + log "github.com/sirupsen/logrus" +) + +// ToolUseState tracks the state of an in-progress tool use during streaming. +type ToolUseState struct { + ToolUseID string + Name string + InputBuffer strings.Builder + IsComplete bool +} + +// Pre-compiled regex patterns for performance +var ( + // embeddedToolCallPattern matches [Called tool_name with args: {...}] format + embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`) + // trailingCommaPattern matches trailing commas before closing braces/brackets + trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) +) + +// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. +// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent. +// Returns the cleaned text (with tool calls removed) and extracted tool uses. 
+func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) { + if !strings.Contains(text, "[Called") { + return text, nil + } + + var toolUses []KiroToolUse + cleanText := text + + // Find all [Called markers + matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return text, nil + } + + // Process matches in reverse order to maintain correct indices + for i := len(matches) - 1; i >= 0; i-- { + matchStart := matches[i][0] + toolNameStart := matches[i][2] + toolNameEnd := matches[i][3] + + if toolNameStart < 0 || toolNameEnd < 0 { + continue + } + + toolName := text[toolNameStart:toolNameEnd] + + // Find the JSON object start (after "with args:") + jsonStart := matches[i][1] + if jsonStart >= len(text) { + continue + } + + // Skip whitespace to find the opening brace + for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') { + jsonStart++ + } + + if jsonStart >= len(text) || text[jsonStart] != '{' { + continue + } + + // Find matching closing bracket + jsonEnd := findMatchingBracket(text, jsonStart) + if jsonEnd < 0 { + continue + } + + // Extract JSON and find the closing bracket of [Called ...] 
+ jsonStr := text[jsonStart : jsonEnd+1] + + // Find the closing ] after the JSON + closingBracket := jsonEnd + 1 + for closingBracket < len(text) && text[closingBracket] != ']' { + closingBracket++ + } + if closingBracket >= len(text) { + continue + } + + // End index of the full tool call (closing ']' inclusive) + matchEnd := closingBracket + 1 + + // Repair and parse JSON + repairedJSON := RepairJSON(jsonStr) + var inputMap map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil { + log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr) + continue + } + + // Generate unique tool ID + toolUseID := "toolu_" + uuid.New().String()[:12] + + // Check for duplicates using name+input as key + dedupeKey := toolName + ":" + repairedJSON + if processedIDs != nil { + if processedIDs[dedupeKey] { + log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName) + // Still remove from text even if duplicate + if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd { + cleanText = cleanText[:matchStart] + cleanText[matchEnd:] + } + continue + } + processedIDs[dedupeKey] = true + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + + log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID) + + // Remove from clean text (index-based removal to avoid deleting the wrong occurrence) + if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd { + cleanText = cleanText[:matchStart] + cleanText[matchEnd:] + } + } + + return cleanText, toolUses +} + +// findMatchingBracket finds the index of the closing brace/bracket that matches +// the opening one at startPos. Handles nested objects and strings correctly. 
+func findMatchingBracket(text string, startPos int) int { + if startPos >= len(text) { + return -1 + } + + openChar := text[startPos] + var closeChar byte + switch openChar { + case '{': + closeChar = '}' + case '[': + closeChar = ']' + default: + return -1 + } + + depth := 1 + inString := false + escapeNext := false + + for i := startPos + 1; i < len(text); i++ { + char := text[i] + + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' && inString { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if !inString { + if char == openChar { + depth++ + } else if char == closeChar { + depth-- + if depth == 0 { + return i + } + } + } + } + + return -1 +} + +// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments. +// Conservative repair strategy: +// 1. First try to parse JSON directly - if valid, return as-is +// 2. Only attempt repair if parsing fails +// 3. After repair, validate the result - if still invalid, return original +func RepairJSON(jsonString string) string { + // Handle empty or invalid input + if jsonString == "" { + return "{}" + } + + str := strings.TrimSpace(jsonString) + if str == "" { + return "{}" + } + + // CONSERVATIVE STRATEGY: First try to parse directly + var testParse interface{} + if err := json.Unmarshal([]byte(str), &testParse); err == nil { + log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged") + return str + } + + log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair") + originalStr := str + + // First, escape unescaped newlines/tabs within JSON string values + str = escapeNewlinesInStrings(str) + // Remove trailing commas before closing braces/brackets + str = trailingCommaPattern.ReplaceAllString(str, "$1") + + // Calculate bracket balance + braceCount := 0 + bracketCount := 0 + inString := false + escape := false + lastValidIndex := -1 + + for i := 0; i < len(str); i++ { + char := str[i] + + if 
escape { + escape = false + continue + } + + if char == '\\' { + escape = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + + if braceCount >= 0 && bracketCount >= 0 { + lastValidIndex = i + } + } + + // If brackets are unbalanced, try to repair + if braceCount > 0 || bracketCount > 0 { + if lastValidIndex > 0 && lastValidIndex < len(str)-1 { + truncated := str[:lastValidIndex+1] + // Recount brackets after truncation + braceCount = 0 + bracketCount = 0 + inString = false + escape = false + for i := 0; i < len(truncated); i++ { + char := truncated[i] + if escape { + escape = false + continue + } + if char == '\\' { + escape = true + continue + } + if char == '"' { + inString = !inString + continue + } + if inString { + continue + } + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + } + str = truncated + } + + // Add missing closing brackets + for braceCount > 0 { + str += "}" + braceCount-- + } + for bracketCount > 0 { + str += "]" + bracketCount-- + } + } + + // Validate repaired JSON + if err := json.Unmarshal([]byte(str), &testParse); err != nil { + log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original") + return originalStr + } + + log.Debugf("kiro: repairJSON - successfully repaired JSON") + return str +} + +// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters +// that appear inside JSON string values. 
+func escapeNewlinesInStrings(raw string) string { + var result strings.Builder + result.Grow(len(raw) + 100) + + inString := false + escaped := false + + for i := 0; i < len(raw); i++ { + c := raw[i] + + if escaped { + result.WriteByte(c) + escaped = false + continue + } + + if c == '\\' && inString { + result.WriteByte(c) + escaped = true + continue + } + + if c == '"' { + inString = !inString + result.WriteByte(c) + continue + } + + if inString { + switch c { + case '\n': + result.WriteString("\\n") + case '\r': + result.WriteString("\\r") + case '\t': + result.WriteString("\\t") + default: + result.WriteByte(c) + } + } else { + result.WriteByte(c) + } + } + + return result.String() +} + +// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream. +// It accumulates input fragments and emits tool_use blocks when complete. +// Returns events to emit and updated state. +func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) { + var toolUses []KiroToolUse + + // Extract from nested toolUseEvent or direct format + tu := event + if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok { + tu = nested + } + + toolUseID := kirocommon.GetString(tu, "toolUseId") + toolName := kirocommon.GetString(tu, "name") + isStop := false + if stop, ok := tu["stop"].(bool); ok { + isStop = stop + } + + // Get input - can be string (fragment) or object (complete) + var inputFragment string + var inputMap map[string]interface{} + + if inputRaw, ok := tu["input"]; ok { + switch v := inputRaw.(type) { + case string: + inputFragment = v + case map[string]interface{}: + inputMap = v + } + } + + // New tool use starting + if toolUseID != "" && toolName != "" { + if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID { + log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous", + toolUseID, currentToolUse.ToolUseID) + if 
!processedIDs[currentToolUse.ToolUseID] { + incomplete := KiroToolUse{ + ToolUseID: currentToolUse.ToolUseID, + Name: currentToolUse.Name, + } + if currentToolUse.InputBuffer.Len() > 0 { + raw := currentToolUse.InputBuffer.String() + repaired := RepairJSON(raw) + + var input map[string]interface{} + if err := json.Unmarshal([]byte(repaired), &input); err != nil { + log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw) + input = make(map[string]interface{}) + } + incomplete.Input = input + } + toolUses = append(toolUses, incomplete) + processedIDs[currentToolUse.ToolUseID] = true + } + currentToolUse = nil + } + + if currentToolUse == nil { + if processedIDs != nil && processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID) + return nil, nil + } + + currentToolUse = &ToolUseState{ + ToolUseID: toolUseID, + Name: toolName, + } + log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID) + } + } + + // Accumulate input fragments + if currentToolUse != nil && inputFragment != "" { + currentToolUse.InputBuffer.WriteString(inputFragment) + log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len()) + } + + // If complete input object provided directly + if currentToolUse != nil && inputMap != nil { + inputBytes, _ := json.Marshal(inputMap) + currentToolUse.InputBuffer.Reset() + currentToolUse.InputBuffer.Write(inputBytes) + } + + // Tool use complete + if isStop && currentToolUse != nil { + fullInput := currentToolUse.InputBuffer.String() + + // Repair and parse the accumulated JSON + repairedJSON := RepairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput) + finalInput = make(map[string]interface{}) + } + + toolUse := KiroToolUse{ + ToolUseID: currentToolUse.ToolUseID, + Name: 
currentToolUse.Name, + Input: finalInput, + } + toolUses = append(toolUses, toolUse) + + if processedIDs != nil { + processedIDs[currentToolUse.ToolUseID] = true + } + + log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) + return toolUses, nil + } + + return toolUses, currentToolUse +} + +// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content. +func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse { + seenIDs := make(map[string]bool) + seenContent := make(map[string]bool) + var unique []KiroToolUse + + for _, tu := range toolUses { + if seenIDs[tu.ToolUseID] { + log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name) + continue + } + + inputJSON, _ := json.Marshal(tu.Input) + contentKey := tu.Name + ":" + string(inputJSON) + + if seenContent[contentKey] { + log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID) + continue + } + + seenIDs[tu.ToolUseID] = true + seenContent[contentKey] = true + unique = append(unique, tu) + } + + return unique +} + diff --git a/internal/translator/kiro/common/constants.go b/internal/translator/kiro/common/constants.go new file mode 100644 index 00000000..1d4b0330 --- /dev/null +++ b/internal/translator/kiro/common/constants.go @@ -0,0 +1,66 @@ +// Package common provides shared constants and utilities for Kiro translator. +package common + +const ( + // KiroMaxToolDescLen is the maximum description length for Kiro API tools. + // Kiro API limit is 10240 bytes, leave room for "..." + KiroMaxToolDescLen = 10237 + + // ThinkingStartTag is the start tag for thinking blocks in responses. + ThinkingStartTag = "" + + // ThinkingEndTag is the end tag for thinking blocks in responses. + ThinkingEndTag = "" + + // KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. + // AWS Kiro API has a 2-3 minute timeout for large file write operations. 
+ KiroAgenticSystemPrompt = ` +# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY) + +You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure. + +## ABSOLUTE LIMITS +- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS +- **RECOMMENDED 300 LINES** or less for optimal performance +- **NEVER** write entire files in one operation if >300 lines + +## MANDATORY CHUNKED WRITE STRATEGY + +### For NEW FILES (>300 lines total): +1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite +2. THEN: Append remaining content in 250-300 line chunks using file append operations +3. REPEAT: Continue appending until complete + +### For EDITING EXISTING FILES: +1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed +2. NEVER rewrite entire files - use incremental modifications +3. Split large refactors into multiple small, focused edits + +### For LARGE CODE GENERATION: +1. Generate in logical sections (imports, types, functions separately) +2. Write each section as a separate operation +3. Use append operations for subsequent sections + +## EXAMPLES OF CORRECT BEHAVIOR + +✅ CORRECT: Writing a 600-line file +- Operation 1: Write lines 1-300 (initial file creation) +- Operation 2: Append lines 301-600 + +✅ CORRECT: Editing multiple functions +- Operation 1: Edit function A +- Operation 2: Edit function B +- Operation 3: Edit function C + +❌ WRONG: Writing 500 lines in single operation → TIMEOUT +❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT +❌ WRONG: Generating massive code blocks without chunking → TIMEOUT + +## WHY THIS MATTERS +- Server has 2-3 minute timeout for operations +- Large writes exceed timeout and FAIL completely +- Chunked writes are FASTER and more RELIABLE +- Failed writes waste time and require retry + +REMEMBER: When in doubt, write LESS per operation. 
Multiple small operations > one large operation.` +) \ No newline at end of file diff --git a/internal/translator/kiro/common/message_merge.go b/internal/translator/kiro/common/message_merge.go new file mode 100644 index 00000000..93f17f28 --- /dev/null +++ b/internal/translator/kiro/common/message_merge.go @@ -0,0 +1,125 @@ +// Package common provides shared utilities for Kiro translators. +package common + +import ( + "encoding/json" + + "github.com/tidwall/gjson" +) + +// MergeAdjacentMessages merges adjacent messages with the same role. +// This reduces API call complexity and improves compatibility. +// Based on AIClient-2-API implementation. +func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result { + if len(messages) <= 1 { + return messages + } + + var merged []gjson.Result + for _, msg := range messages { + if len(merged) == 0 { + merged = append(merged, msg) + continue + } + + lastMsg := merged[len(merged)-1] + currentRole := msg.Get("role").String() + lastRole := lastMsg.Get("role").String() + + if currentRole == lastRole { + // Merge content from current message into last message + mergedContent := mergeMessageContent(lastMsg, msg) + // Create a new merged message JSON + mergedMsg := createMergedMessage(lastRole, mergedContent) + merged[len(merged)-1] = gjson.Parse(mergedMsg) + } else { + merged = append(merged, msg) + } + } + + return merged +} + +// mergeMessageContent merges the content of two messages with the same role. +// Handles both string content and array content (with text, tool_use, tool_result blocks). 
+func mergeMessageContent(msg1, msg2 gjson.Result) string { + content1 := msg1.Get("content") + content2 := msg2.Get("content") + + // Extract content blocks from both messages + var blocks1, blocks2 []map[string]interface{} + + if content1.IsArray() { + for _, block := range content1.Array() { + blocks1 = append(blocks1, blockToMap(block)) + } + } else if content1.Type == gjson.String { + blocks1 = append(blocks1, map[string]interface{}{ + "type": "text", + "text": content1.String(), + }) + } + + if content2.IsArray() { + for _, block := range content2.Array() { + blocks2 = append(blocks2, blockToMap(block)) + } + } else if content2.Type == gjson.String { + blocks2 = append(blocks2, map[string]interface{}{ + "type": "text", + "text": content2.String(), + }) + } + + // Merge text blocks if both end/start with text + if len(blocks1) > 0 && len(blocks2) > 0 { + if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" { + // Merge the last text block of msg1 with the first text block of msg2 + text1 := blocks1[len(blocks1)-1]["text"].(string) + text2 := blocks2[0]["text"].(string) + blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2 + blocks2 = blocks2[1:] // Remove the merged block from blocks2 + } + } + + // Combine all blocks + allBlocks := append(blocks1, blocks2...) 
+ + // Convert to JSON + result, _ := json.Marshal(allBlocks) + return string(result) +} + +// blockToMap converts a gjson.Result block to a map[string]interface{} +func blockToMap(block gjson.Result) map[string]interface{} { + result := make(map[string]interface{}) + block.ForEach(func(key, value gjson.Result) bool { + if value.IsObject() { + result[key.String()] = blockToMap(value) + } else if value.IsArray() { + var arr []interface{} + for _, item := range value.Array() { + if item.IsObject() { + arr = append(arr, blockToMap(item)) + } else { + arr = append(arr, item.Value()) + } + } + result[key.String()] = arr + } else { + result[key.String()] = value.Value() + } + return true + }) + return result +} + +// createMergedMessage creates a JSON string for a merged message +func createMergedMessage(role string, content string) string { + msg := map[string]interface{}{ + "role": role, + "content": json.RawMessage(content), + } + result, _ := json.Marshal(msg) + return string(result) +} \ No newline at end of file diff --git a/internal/translator/kiro/common/utils.go b/internal/translator/kiro/common/utils.go new file mode 100644 index 00000000..f5f5788a --- /dev/null +++ b/internal/translator/kiro/common/utils.go @@ -0,0 +1,16 @@ +// Package common provides shared constants and utilities for Kiro translator. +package common + +// GetString safely extracts a string from a map. +// Returns empty string if the key doesn't exist or the value is not a string. +func GetString(m map[string]interface{}, key string) string { + if v, ok := m[key].(string); ok { + return v + } + return "" +} + +// GetStringValue is an alias for GetString for backward compatibility. 
+func GetStringValue(m map[string]interface{}, key string) string { + return GetString(m, key) +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go deleted file mode 100644 index d1094c1c..00000000 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_request.go +++ /dev/null @@ -1,348 +0,0 @@ -// Package chat_completions provides request translation from OpenAI to Kiro format. -package chat_completions - -import ( - "bytes" - "encoding/json" - "strings" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// reasoningEffortToBudget maps OpenAI reasoning_effort values to Claude thinking budget_tokens. -// OpenAI uses "low", "medium", "high" while Claude uses numeric budget_tokens. -var reasoningEffortToBudget = map[string]int{ - "low": 4000, - "medium": 16000, - "high": 32000, -} - -// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format. -// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format. -// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result. -// Supports reasoning/thinking: OpenAI reasoning_effort -> Claude thinking parameter. 
-func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { - rawJSON := bytes.Clone(inputRawJSON) - root := gjson.ParseBytes(rawJSON) - - // Build Claude-compatible request - out := `{"model":"","max_tokens":32000,"messages":[]}` - - // Set model - out, _ = sjson.Set(out, "model", modelName) - - // Copy max_tokens if present - if v := root.Get("max_tokens"); v.Exists() { - out, _ = sjson.Set(out, "max_tokens", v.Int()) - } - - // Copy temperature if present - if v := root.Get("temperature"); v.Exists() { - out, _ = sjson.Set(out, "temperature", v.Float()) - } - - // Copy top_p if present - if v := root.Get("top_p"); v.Exists() { - out, _ = sjson.Set(out, "top_p", v.Float()) - } - - // Handle OpenAI reasoning_effort parameter -> Claude thinking parameter - // OpenAI format: {"reasoning_effort": "low"|"medium"|"high"} - // Claude format: {"thinking": {"type": "enabled", "budget_tokens": N}} - if v := root.Get("reasoning_effort"); v.Exists() { - effort := v.String() - if budget, ok := reasoningEffortToBudget[effort]; ok { - thinking := map[string]interface{}{ - "type": "enabled", - "budget_tokens": budget, - } - out, _ = sjson.Set(out, "thinking", thinking) - } - } - - // Also support direct thinking parameter passthrough (for Claude API compatibility) - // Claude format: {"thinking": {"type": "enabled", "budget_tokens": N}} - if v := root.Get("thinking"); v.Exists() && v.IsObject() { - out, _ = sjson.Set(out, "thinking", v.Value()) - } - - // Convert OpenAI tools to Claude tools format - if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { - claudeTools := make([]interface{}, 0) - for _, tool := range tools.Array() { - if tool.Get("type").String() == "function" { - fn := tool.Get("function") - claudeTool := map[string]interface{}{ - "name": fn.Get("name").String(), - "description": fn.Get("description").String(), - } - // Convert parameters to input_schema - if params := fn.Get("parameters"); params.Exists() { - 
claudeTool["input_schema"] = params.Value() - } else { - claudeTool["input_schema"] = map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{}, - } - } - claudeTools = append(claudeTools, claudeTool) - } - } - if len(claudeTools) > 0 { - out, _ = sjson.Set(out, "tools", claudeTools) - } - } - - // Process messages - messages := root.Get("messages") - if messages.Exists() && messages.IsArray() { - claudeMessages := make([]interface{}, 0) - var systemPrompt string - - // Track pending tool results to merge with next user message - var pendingToolResults []map[string]interface{} - - for _, msg := range messages.Array() { - role := msg.Get("role").String() - content := msg.Get("content") - - if role == "system" { - // Extract system message - if content.IsArray() { - for _, part := range content.Array() { - if part.Get("type").String() == "text" { - systemPrompt += part.Get("text").String() + "\n" - } - } - } else { - systemPrompt = content.String() - } - continue - } - - if role == "tool" { - // OpenAI tool message -> Claude tool_result content block - toolCallID := msg.Get("tool_call_id").String() - toolContent := content.String() - - toolResult := map[string]interface{}{ - "type": "tool_result", - "tool_use_id": toolCallID, - } - - // Handle content - can be string or structured - if content.IsArray() { - contentParts := make([]interface{}, 0) - for _, part := range content.Array() { - if part.Get("type").String() == "text" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "text", - "text": part.Get("text").String(), - }) - } - } - toolResult["content"] = contentParts - } else { - toolResult["content"] = toolContent - } - - pendingToolResults = append(pendingToolResults, toolResult) - continue - } - - claudeMsg := map[string]interface{}{ - "role": role, - } - - // Handle assistant messages with tool_calls - if role == "assistant" && msg.Get("tool_calls").Exists() { - contentParts := make([]interface{}, 0) - - // 
Add text content if present - if content.Exists() && content.String() != "" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "text", - "text": content.String(), - }) - } - - // Convert tool_calls to tool_use blocks - for _, toolCall := range msg.Get("tool_calls").Array() { - toolUseID := toolCall.Get("id").String() - fnName := toolCall.Get("function.name").String() - fnArgs := toolCall.Get("function.arguments").String() - - // Parse arguments JSON - var argsMap map[string]interface{} - if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil { - argsMap = map[string]interface{}{"raw": fnArgs} - } - - contentParts = append(contentParts, map[string]interface{}{ - "type": "tool_use", - "id": toolUseID, - "name": fnName, - "input": argsMap, - }) - } - - claudeMsg["content"] = contentParts - claudeMessages = append(claudeMessages, claudeMsg) - continue - } - - // Handle user messages - may need to include pending tool results - if role == "user" && len(pendingToolResults) > 0 { - contentParts := make([]interface{}, 0) - - // Add pending tool results first - for _, tr := range pendingToolResults { - contentParts = append(contentParts, tr) - } - pendingToolResults = nil - - // Add user content - if content.IsArray() { - for _, part := range content.Array() { - partType := part.Get("type").String() - if partType == "text" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "text", - "text": part.Get("text").String(), - }) - } else if partType == "image_url" { - imageURL := part.Get("image_url.url").String() - - // Check if it's base64 format (data:image/png;base64,xxxxx) - if strings.HasPrefix(imageURL, "data:") { - // Parse data URL format - // Format: data:image/png;base64,xxxxx - commaIdx := strings.Index(imageURL, ",") - if commaIdx != -1 { - // Extract media_type (e.g., "image/png") - header := imageURL[5:commaIdx] // Remove "data:" prefix - mediaType := header - if semiIdx := strings.Index(header, ";"); semiIdx != 
-1 { - mediaType = header[:semiIdx] - } - - // Extract base64 data - base64Data := imageURL[commaIdx+1:] - - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "base64", - "media_type": mediaType, - "data": base64Data, - }, - }) - } - } else { - // Regular URL format - keep original logic - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": imageURL, - }, - }) - } - } - } - } else if content.String() != "" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "text", - "text": content.String(), - }) - } - - claudeMsg["content"] = contentParts - claudeMessages = append(claudeMessages, claudeMsg) - continue - } - - // Handle regular content - if content.IsArray() { - contentParts := make([]interface{}, 0) - for _, part := range content.Array() { - partType := part.Get("type").String() - if partType == "text" { - contentParts = append(contentParts, map[string]interface{}{ - "type": "text", - "text": part.Get("text").String(), - }) - } else if partType == "image_url" { - imageURL := part.Get("image_url.url").String() - - // Check if it's base64 format (data:image/png;base64,xxxxx) - if strings.HasPrefix(imageURL, "data:") { - // Parse data URL format - // Format: data:image/png;base64,xxxxx - commaIdx := strings.Index(imageURL, ",") - if commaIdx != -1 { - // Extract media_type (e.g., "image/png") - header := imageURL[5:commaIdx] // Remove "data:" prefix - mediaType := header - if semiIdx := strings.Index(header, ";"); semiIdx != -1 { - mediaType = header[:semiIdx] - } - - // Extract base64 data - base64Data := imageURL[commaIdx+1:] - - contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "base64", - "media_type": mediaType, - "data": base64Data, - }, - }) - } - } else { - // Regular URL format - keep original logic - 
contentParts = append(contentParts, map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": "url", - "url": imageURL, - }, - }) - } - } - } - claudeMsg["content"] = contentParts - } else { - claudeMsg["content"] = content.String() - } - - claudeMessages = append(claudeMessages, claudeMsg) - } - - // If there are pending tool results without a following user message, - // create a user message with just the tool results - if len(pendingToolResults) > 0 { - contentParts := make([]interface{}, 0) - for _, tr := range pendingToolResults { - contentParts = append(contentParts, tr) - } - claudeMessages = append(claudeMessages, map[string]interface{}{ - "role": "user", - "content": contentParts, - }) - } - - out, _ = sjson.Set(out, "messages", claudeMessages) - - if systemPrompt != "" { - out, _ = sjson.Set(out, "system", systemPrompt) - } - } - - // Set stream - out, _ = sjson.Set(out, "stream", stream) - - return []byte(out) -} diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go deleted file mode 100644 index 2fab2a4d..00000000 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go +++ /dev/null @@ -1,404 +0,0 @@ -// Package chat_completions provides response translation from Kiro to OpenAI format. -package chat_completions - -import ( - "context" - "encoding/json" - "strings" - "time" - - "github.com/google/uuid" - "github.com/tidwall/gjson" -) - -// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format. -// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta, -// content_block_stop, message_delta, and message_stop. -// Input may be in SSE format: "event: xxx\ndata: {...}" or raw JSON. 
-func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { - raw := string(rawResponse) - var results []string - - // Handle SSE format: extract JSON from "data: " lines - // Input format: "event: message_start\ndata: {...}" - lines := strings.Split(raw, "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "data: ") { - jsonPart := strings.TrimPrefix(line, "data: ") - chunks := convertClaudeEventToOpenAI(jsonPart, model) - results = append(results, chunks...) - } else if strings.HasPrefix(line, "{") { - // Raw JSON (backward compatibility) - chunks := convertClaudeEventToOpenAI(line, model) - results = append(results, chunks...) - } - } - - return results -} - -// convertClaudeEventToOpenAI converts a single Claude JSON event to OpenAI format -func convertClaudeEventToOpenAI(jsonStr string, model string) []string { - root := gjson.Parse(jsonStr) - var results []string - - eventType := root.Get("type").String() - - switch eventType { - case "message_start": - // Initial message event - emit initial chunk with role - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "role": "assistant", - "content": "", - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - return results - - case "content_block_start": - // Start of a content block (text or tool_use) - blockType := root.Get("content_block.type").String() - index := int(root.Get("index").Int()) - - if blockType == "tool_use" { - // Start of tool_use block - toolUseID := root.Get("content_block.id").String() - toolName := root.Get("content_block.name").String() - - toolCall := map[string]interface{}{ - "index": 
index, - "id": toolUseID, - "type": "function", - "function": map[string]interface{}{ - "name": toolName, - "arguments": "", - }, - } - - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "tool_calls": []map[string]interface{}{toolCall}, - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - return results - - case "content_block_delta": - index := int(root.Get("index").Int()) - deltaType := root.Get("delta.type").String() - - if deltaType == "text_delta" { - // Text content delta - contentDelta := root.Get("delta.text").String() - if contentDelta != "" { - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "content": contentDelta, - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - } else if deltaType == "thinking_delta" { - // Thinking/reasoning content delta - convert to OpenAI reasoning_content format - thinkingDelta := root.Get("delta.thinking").String() - if thinkingDelta != "" { - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "reasoning_content": thinkingDelta, - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - } else if deltaType == "input_json_delta" { - // Tool input delta (streaming 
arguments) - partialJSON := root.Get("delta.partial_json").String() - if partialJSON != "" { - toolCall := map[string]interface{}{ - "index": index, - "function": map[string]interface{}{ - "arguments": partialJSON, - }, - } - - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "tool_calls": []map[string]interface{}{toolCall}, - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - } - return results - - case "content_block_stop": - // End of content block - no output needed for OpenAI format - return results - - case "message_delta": - // Final message delta with stop_reason and usage - stopReason := root.Get("delta.stop_reason").String() - if stopReason != "" { - finishReason := "stop" - if stopReason == "tool_use" { - finishReason = "tool_calls" - } else if stopReason == "end_turn" { - finishReason = "stop" - } else if stopReason == "max_tokens" { - finishReason = "length" - } - - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{}, - "finish_reason": finishReason, - }, - }, - } - - // Extract and include usage information from message_delta event - usage := root.Get("usage") - if usage.Exists() { - inputTokens := usage.Get("input_tokens").Int() - outputTokens := usage.Get("output_tokens").Int() - response["usage"] = map[string]interface{}{ - "prompt_tokens": inputTokens, - "completion_tokens": outputTokens, - "total_tokens": inputTokens + outputTokens, - } - } - - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - return results - - 
case "message_stop": - // End of message - could emit [DONE] marker - return results - } - - // Fallback: handle raw content for backward compatibility - var contentDelta string - if delta := root.Get("delta.text"); delta.Exists() { - contentDelta = delta.String() - } else if content := root.Get("content"); content.Exists() && root.Get("type").String() == "" { - contentDelta = content.String() - } - - if contentDelta != "" { - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "content": contentDelta, - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - - // Handle tool_use content blocks (Claude format) - fallback - toolUses := root.Get("delta.tool_use") - if !toolUses.Exists() { - toolUses = root.Get("tool_use") - } - if toolUses.Exists() && toolUses.IsObject() { - inputJSON := toolUses.Get("input").String() - if inputJSON == "" { - if inputObj := toolUses.Get("input"); inputObj.Exists() { - inputBytes, _ := json.Marshal(inputObj.Value()) - inputJSON = string(inputBytes) - } - } - - toolCall := map[string]interface{}{ - "index": 0, - "id": toolUses.Get("id").String(), - "type": "function", - "function": map[string]interface{}{ - "name": toolUses.Get("name").String(), - "arguments": inputJSON, - }, - } - - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "delta": map[string]interface{}{ - "tool_calls": []map[string]interface{}{toolCall}, - }, - "finish_reason": nil, - }, - }, - } - result, _ := json.Marshal(response) - results = append(results, string(result)) - } - - return results -} - -// 
ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format. -func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { - root := gjson.ParseBytes(rawResponse) - - var content string - var reasoningContent string - var toolCalls []map[string]interface{} - - contentArray := root.Get("content") - if contentArray.IsArray() { - for _, item := range contentArray.Array() { - itemType := item.Get("type").String() - if itemType == "text" { - content += item.Get("text").String() - } else if itemType == "thinking" { - // Extract thinking/reasoning content - reasoningContent += item.Get("thinking").String() - } else if itemType == "tool_use" { - // Convert Claude tool_use to OpenAI tool_calls format - inputJSON := item.Get("input").String() - if inputJSON == "" { - // If input is an object, marshal it - if inputObj := item.Get("input"); inputObj.Exists() { - inputBytes, _ := json.Marshal(inputObj.Value()) - inputJSON = string(inputBytes) - } - } - toolCall := map[string]interface{}{ - "id": item.Get("id").String(), - "type": "function", - "function": map[string]interface{}{ - "name": item.Get("name").String(), - "arguments": inputJSON, - }, - } - toolCalls = append(toolCalls, toolCall) - } - } - } else { - content = root.Get("content").String() - } - - inputTokens := root.Get("usage.input_tokens").Int() - outputTokens := root.Get("usage.output_tokens").Int() - - message := map[string]interface{}{ - "role": "assistant", - "content": content, - } - - // Add reasoning_content if present (OpenAI reasoning format) - if reasoningContent != "" { - message["reasoning_content"] = reasoningContent - } - - // Add tool_calls if present - if len(toolCalls) > 0 { - message["tool_calls"] = toolCalls - } - - finishReason := "stop" - if len(toolCalls) > 0 { - finishReason = "tool_calls" - } - - response := map[string]interface{}{ - "id": "chatcmpl-" + uuid.New().String()[:24], 
- "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]interface{}{ - { - "index": 0, - "message": message, - "finish_reason": finishReason, - }, - }, - "usage": map[string]interface{}{ - "prompt_tokens": inputTokens, - "completion_tokens": outputTokens, - "total_tokens": inputTokens + outputTokens, - }, - } - - result, _ := json.Marshal(response) - return string(result) -} diff --git a/internal/translator/kiro/openai/chat-completions/init.go b/internal/translator/kiro/openai/init.go similarity index 56% rename from internal/translator/kiro/openai/chat-completions/init.go rename to internal/translator/kiro/openai/init.go index 2a99d0e0..653eed45 100644 --- a/internal/translator/kiro/openai/chat-completions/init.go +++ b/internal/translator/kiro/openai/init.go @@ -1,4 +1,5 @@ -package chat_completions +// Package openai provides translation between OpenAI Chat Completions and Kiro formats. +package openai import ( . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -8,12 +9,12 @@ import ( func init() { translator.Register( - OpenAI, - Kiro, + OpenAI, // source format + Kiro, // target format ConvertOpenAIRequestToKiro, interfaces.TranslateResponse{ - Stream: ConvertKiroResponseToOpenAI, - NonStream: ConvertKiroResponseToOpenAINonStream, + Stream: ConvertKiroStreamToOpenAI, + NonStream: ConvertKiroNonStreamToOpenAI, }, ) -} +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai.go b/internal/translator/kiro/openai/kiro_openai.go new file mode 100644 index 00000000..35cd0424 --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai.go @@ -0,0 +1,368 @@ +// Package openai provides translation between OpenAI Chat Completions and Kiro formats. +// This package enables direct OpenAI → Kiro translation, bypassing the Claude intermediate layer. 
+// +// The Kiro executor generates Claude-compatible SSE format internally, so the streaming response +// translation converts from Claude SSE format to OpenAI SSE format. +package openai + +import ( + "bytes" + "context" + "encoding/json" + "strings" + + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// ConvertKiroStreamToOpenAI converts Kiro streaming response to OpenAI format. +// The Kiro executor emits Claude-compatible SSE events, so this function translates +// from Claude SSE format to OpenAI SSE format. +// +// Claude SSE format: +// - event: message_start\ndata: {...} +// - event: content_block_start\ndata: {...} +// - event: content_block_delta\ndata: {...} +// - event: content_block_stop\ndata: {...} +// - event: message_delta\ndata: {...} +// - event: message_stop\ndata: {...} +// +// OpenAI SSE format: +// - data: {"id":"...","object":"chat.completion.chunk",...} +// - data: [DONE] +func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + // Initialize state if needed + if *param == nil { + *param = NewOpenAIStreamState(model) + } + state := (*param).(*OpenAIStreamState) + + // Parse the Claude SSE event + responseStr := string(rawResponse) + + // Handle raw event format (event: xxx\ndata: {...}) + var eventType string + var eventData string + + if strings.HasPrefix(responseStr, "event:") { + // Parse event type and data + lines := strings.SplitN(responseStr, "\n", 2) + if len(lines) >= 1 { + eventType = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:")) + } + if len(lines) >= 2 && strings.HasPrefix(lines[1], "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:")) + } + } else if strings.HasPrefix(responseStr, "data:") { + // Just data line + eventData = 
strings.TrimSpace(strings.TrimPrefix(responseStr, "data:")) + } else { + // Try to parse as raw JSON + eventData = strings.TrimSpace(responseStr) + } + + if eventData == "" { + return []string{} + } + + // Parse the event data as JSON + eventJSON := gjson.Parse(eventData) + if !eventJSON.Exists() { + return []string{} + } + + // Determine event type from JSON if not already set + if eventType == "" { + eventType = eventJSON.Get("type").String() + } + + var results []string + + switch eventType { + case "message_start": + // Send first chunk with role + firstChunk := BuildOpenAISSEFirstChunk(state) + results = append(results, firstChunk) + + case "content_block_start": + // Check block type + blockType := eventJSON.Get("content_block.type").String() + switch blockType { + case "text": + // Text block starting - nothing to emit yet + case "thinking": + // Thinking block starting - nothing to emit yet for OpenAI + case "tool_use": + // Tool use block starting + toolUseID := eventJSON.Get("content_block.id").String() + toolName := eventJSON.Get("content_block.name").String() + chunk := BuildOpenAISSEToolCallStart(state, toolUseID, toolName) + results = append(results, chunk) + state.ToolCallIndex++ + } + + case "content_block_delta": + deltaType := eventJSON.Get("delta.type").String() + switch deltaType { + case "text_delta": + textDelta := eventJSON.Get("delta.text").String() + if textDelta != "" { + chunk := BuildOpenAISSETextDelta(state, textDelta) + results = append(results, chunk) + } + case "thinking_delta": + // Convert thinking to reasoning_content for o1-style compatibility + thinkingDelta := eventJSON.Get("delta.thinking").String() + if thinkingDelta != "" { + chunk := BuildOpenAISSEReasoningDelta(state, thinkingDelta) + results = append(results, chunk) + } + case "input_json_delta": + // Tool call arguments delta + partialJSON := eventJSON.Get("delta.partial_json").String() + if partialJSON != "" { + // Get the tool index from content block index + 
blockIndex := int(eventJSON.Get("index").Int()) + chunk := BuildOpenAISSEToolCallArgumentsDelta(state, partialJSON, blockIndex-1) // Adjust for 0-based tool index + results = append(results, chunk) + } + } + + case "content_block_stop": + // Content block ended - nothing to emit for OpenAI + + case "message_delta": + // Message delta with stop_reason + stopReason := eventJSON.Get("delta.stop_reason").String() + finishReason := mapKiroStopReasonToOpenAI(stopReason) + if finishReason != "" { + chunk := BuildOpenAISSEFinish(state, finishReason) + results = append(results, chunk) + } + + // Extract usage if present + if eventJSON.Get("usage").Exists() { + inputTokens := eventJSON.Get("usage.input_tokens").Int() + outputTokens := eventJSON.Get("usage.output_tokens").Int() + usageInfo := usage.Detail{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + } + chunk := BuildOpenAISSEUsage(state, usageInfo) + results = append(results, chunk) + } + + case "message_stop": + // Final event - emit [DONE] + results = append(results, BuildOpenAISSEDone()) + + case "ping": + // Ping event with usage - optionally emit usage chunk + if eventJSON.Get("usage").Exists() { + inputTokens := eventJSON.Get("usage.input_tokens").Int() + outputTokens := eventJSON.Get("usage.output_tokens").Int() + usageInfo := usage.Detail{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + } + chunk := BuildOpenAISSEUsage(state, usageInfo) + results = append(results, chunk) + } + } + + return results +} + +// ConvertKiroNonStreamToOpenAI converts Kiro non-streaming response to OpenAI format. +// The Kiro executor returns Claude-compatible JSON responses, so this function translates +// from Claude format to OpenAI format. 
+func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + // Parse the Claude-format response + response := gjson.ParseBytes(rawResponse) + + // Extract content + var content string + var toolUses []KiroToolUse + var stopReason string + + // Get stop_reason + stopReason = response.Get("stop_reason").String() + + // Process content blocks + contentBlocks := response.Get("content") + if contentBlocks.IsArray() { + for _, block := range contentBlocks.Array() { + blockType := block.Get("type").String() + switch blockType { + case "text": + content += block.Get("text").String() + case "thinking": + // Skip thinking blocks for OpenAI format (or convert to reasoning_content if needed) + case "tool_use": + toolUseID := block.Get("id").String() + toolName := block.Get("name").String() + toolInput := block.Get("input") + + var inputMap map[string]interface{} + if toolInput.IsObject() { + inputMap = make(map[string]interface{}) + toolInput.ForEach(func(key, value gjson.Result) bool { + inputMap[key.String()] = value.Value() + return true + }) + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + } + + // Extract usage + usageInfo := usage.Detail{ + InputTokens: response.Get("usage.input_tokens").Int(), + OutputTokens: response.Get("usage.output_tokens").Int(), + } + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens + + // Build OpenAI response + openaiResponse := BuildOpenAIResponse(content, toolUses, model, usageInfo, stopReason) + return string(openaiResponse) +} + +// ParseClaudeEvent parses a Claude SSE event and returns the event type and data +func ParseClaudeEvent(rawEvent []byte) (eventType string, eventData []byte) { + lines := bytes.Split(rawEvent, []byte("\n")) + for _, line := range lines { + line = bytes.TrimSpace(line) + if bytes.HasPrefix(line, []byte("event:")) { + eventType = 
string(bytes.TrimSpace(bytes.TrimPrefix(line, []byte("event:")))) + } else if bytes.HasPrefix(line, []byte("data:")) { + eventData = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:"))) + } + } + return eventType, eventData +} + +// ExtractThinkingFromContent parses content to extract thinking blocks. +// Returns cleaned content (without thinking tags) and whether thinking was found. +func ExtractThinkingFromContent(content string) (string, string, bool) { + if !strings.Contains(content, kirocommon.ThinkingStartTag) { + return content, "", false + } + + var cleanedContent strings.Builder + var thinkingContent strings.Builder + hasThinking := false + remaining := content + + for len(remaining) > 0 { + startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + if startIdx == -1 { + cleanedContent.WriteString(remaining) + break + } + + // Add content before thinking tag + cleanedContent.WriteString(remaining[:startIdx]) + + // Move past opening tag + remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):] + + // Find closing tag + endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag) + if endIdx == -1 { + // No closing tag - treat rest as thinking + thinkingContent.WriteString(remaining) + hasThinking = true + break + } + + // Extract thinking content + thinkingContent.WriteString(remaining[:endIdx]) + hasThinking = true + remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):] + } + + return strings.TrimSpace(cleanedContent.String()), strings.TrimSpace(thinkingContent.String()), hasThinking +} + +// ConvertOpenAIToolsToKiroFormat is a helper that converts OpenAI tools format to Kiro format +func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWrapper { + var kiroTools []KiroToolWrapper + + for _, tool := range tools { + toolType, _ := tool["type"].(string) + if toolType != "function" { + continue + } + + fn, ok := tool["function"].(map[string]interface{}) + if !ok { + continue + } + + name := 
kirocommon.GetString(fn, "name") + description := kirocommon.GetString(fn, "description") + parameters := fn["parameters"] + + if name == "" { + continue + } + + if description == "" { + description = "Tool: " + name + } + + kiroTools = append(kiroTools, KiroToolWrapper{ + ToolSpecification: KiroToolSpecification{ + Name: name, + Description: description, + InputSchema: KiroInputSchema{JSON: parameters}, + }, + }) + } + + return kiroTools +} + +// OpenAIStreamParams holds parameters for OpenAI streaming conversion +type OpenAIStreamParams struct { + State *OpenAIStreamState + ThinkingState *ThinkingTagState + ToolCallsEmitted map[string]bool +} + +// NewOpenAIStreamParams creates new streaming parameters +func NewOpenAIStreamParams(model string) *OpenAIStreamParams { + return &OpenAIStreamParams{ + State: NewOpenAIStreamState(model), + ThinkingState: NewThinkingTagState(), + ToolCallsEmitted: make(map[string]bool), + } +} + +// ConvertClaudeToolUseToOpenAI converts a Claude tool_use block to OpenAI tool_calls format +func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]interface{}) map[string]interface{} { + inputJSON, _ := json.Marshal(input) + return map[string]interface{}{ + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": string(inputJSON), + }, + } +} + +// LogStreamEvent logs a streaming event for debugging +func LogStreamEvent(eventType, data string) { + log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data)) +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai_request.go b/internal/translator/kiro/openai/kiro_openai_request.go new file mode 100644 index 00000000..4aaa8b4e --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai_request.go @@ -0,0 +1,604 @@ +// Package openai provides request translation from OpenAI Chat Completions to Kiro format. 
+// It handles parsing and transforming OpenAI API requests into the Kiro/Amazon Q API format, +// extracting model information, system instructions, message contents, and tool declarations. +package openai + +import ( + "encoding/json" + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// Kiro API request structs - reuse from kiroclaude package structure + +// KiroPayload is the top-level request structure for Kiro API +type KiroPayload struct { + ConversationState KiroConversationState `json:"conversationState"` + ProfileArn string `json:"profileArn,omitempty"` + InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"` +} + +// KiroInferenceConfig contains inference parameters for the Kiro API. +type KiroInferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +// KiroConversationState holds the conversation context +type KiroConversationState struct { + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" + ConversationID string `json:"conversationId"` + CurrentMessage KiroCurrentMessage `json:"currentMessage"` + History []KiroHistoryMessage `json:"history,omitempty"` +} + +// KiroCurrentMessage wraps the current user message +type KiroCurrentMessage struct { + UserInputMessage KiroUserInputMessage `json:"userInputMessage"` +} + +// KiroHistoryMessage represents a message in the conversation history +type KiroHistoryMessage struct { + UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"` + AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` +} + +// KiroImage represents an image in Kiro API format +type KiroImage struct { + Format string `json:"format"` + Source KiroImageSource `json:"source"` +} + +// 
KiroImageSource contains the image data +type KiroImageSource struct { + Bytes string `json:"bytes"` // base64 encoded image data +} + +// KiroUserInputMessage represents a user message +type KiroUserInputMessage struct { + Content string `json:"content"` + ModelID string `json:"modelId"` + Origin string `json:"origin"` + Images []KiroImage `json:"images,omitempty"` + UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` +} + +// KiroUserInputMessageContext contains tool-related context +type KiroUserInputMessageContext struct { + ToolResults []KiroToolResult `json:"toolResults,omitempty"` + Tools []KiroToolWrapper `json:"tools,omitempty"` +} + +// KiroToolResult represents a tool execution result +type KiroToolResult struct { + Content []KiroTextContent `json:"content"` + Status string `json:"status"` + ToolUseID string `json:"toolUseId"` +} + +// KiroTextContent represents text content +type KiroTextContent struct { + Text string `json:"text"` +} + +// KiroToolWrapper wraps a tool specification +type KiroToolWrapper struct { + ToolSpecification KiroToolSpecification `json:"toolSpecification"` +} + +// KiroToolSpecification defines a tool's schema +type KiroToolSpecification struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema KiroInputSchema `json:"inputSchema"` +} + +// KiroInputSchema wraps the JSON schema for tool input +type KiroInputSchema struct { + JSON interface{} `json:"json"` +} + +// KiroAssistantResponseMessage represents an assistant message +type KiroAssistantResponseMessage struct { + Content string `json:"content"` + ToolUses []KiroToolUse `json:"toolUses,omitempty"` +} + +// KiroToolUse represents a tool invocation by the assistant +type KiroToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + +// ConvertOpenAIRequestToKiro converts an OpenAI Chat Completions request to Kiro format. 
+// This is the main entry point for request translation. +// Note: The actual payload building happens in the executor, this just passes through +// the OpenAI format which will be converted by BuildKiroPayloadFromOpenAI. +func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + // Pass through the OpenAI format - actual conversion happens in BuildKiroPayloadFromOpenAI + return inputRawJSON +} + +// BuildKiroPayloadFromOpenAI constructs the Kiro API request payload from OpenAI format. +// Supports tool calling - tools are passed via userInputMessageContext. +// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +// isAgentic parameter enables chunked write optimization prompt for -agentic model variants. +// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). +func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { + // Extract max_tokens for potential use in inferenceConfig + var maxTokens int64 + if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() { + maxTokens = mt.Int() + } + + // Extract temperature if specified + var temperature float64 + var hasTemperature bool + if temp := gjson.GetBytes(openaiBody, "temperature"); temp.Exists() { + temperature = temp.Float() + hasTemperature = true + } + + // Normalize origin value for Kiro API compatibility + origin = normalizeOrigin(origin) + log.Debugf("kiro-openai: normalized origin value: %s", origin) + + messages := gjson.GetBytes(openaiBody, "messages") + + // For chat-only mode, don't include tools + var tools gjson.Result + if !isChatOnly { + tools = gjson.GetBytes(openaiBody, "tools") + } + + // Extract system prompt from messages + systemPrompt := extractSystemPromptFromOpenAI(messages) + + // Inject timestamp context + timestamp := time.Now().Format("2006-01-02 15:04:05 MST") + timestampContext := 
fmt.Sprintf("[Context: Current time is %s]", timestamp) + if systemPrompt != "" { + systemPrompt = timestampContext + "\n\n" + systemPrompt + } else { + systemPrompt = timestampContext + } + log.Debugf("kiro-openai: injected timestamp context: %s", timestamp) + + // Inject agentic optimization prompt for -agentic model variants + if isAgentic { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += kirocommon.KiroAgenticSystemPrompt + } + + // Convert OpenAI tools to Kiro format + kiroTools := convertOpenAIToolsToKiro(tools) + + // Process messages and build history + history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin) + + // Build content with system prompt + if currentUserMsg != nil { + currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults) + + // Deduplicate currentToolResults + currentToolResults = deduplicateToolResults(currentToolResults) + + // Build userInputMessageContext with tools and tool results + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + } + + // Build payload + var currentMessage KiroCurrentMessage + if currentUserMsg != nil { + currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg} + } else { + fallbackContent := "" + if systemPrompt != "" { + fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n" + } + currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{ + Content: fallbackContent, + ModelID: modelID, + Origin: origin, + }} + } + + // Build inferenceConfig if we have any inference parameters + var inferenceConfig *KiroInferenceConfig + if maxTokens > 0 || hasTemperature { + inferenceConfig = &KiroInferenceConfig{} + if maxTokens > 0 { + inferenceConfig.MaxTokens = int(maxTokens) + } + if hasTemperature { + 
inferenceConfig.Temperature = temperature + } + } + + payload := KiroPayload{ + ConversationState: KiroConversationState{ + ChatTriggerType: "MANUAL", + ConversationID: uuid.New().String(), + CurrentMessage: currentMessage, + History: history, + }, + ProfileArn: profileArn, + InferenceConfig: inferenceConfig, + } + + result, err := json.Marshal(payload) + if err != nil { + log.Debugf("kiro-openai: failed to marshal payload: %v", err) + return nil + } + + return result +} + +// normalizeOrigin normalizes origin value for Kiro API compatibility +func normalizeOrigin(origin string) string { + switch origin { + case "KIRO_CLI": + return "CLI" + case "KIRO_AI_EDITOR": + return "AI_EDITOR" + case "AMAZON_Q": + return "CLI" + case "KIRO_IDE": + return "AI_EDITOR" + default: + return origin + } +} + +// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages +func extractSystemPromptFromOpenAI(messages gjson.Result) string { + if !messages.IsArray() { + return "" + } + + var systemParts []string + for _, msg := range messages.Array() { + if msg.Get("role").String() == "system" { + content := msg.Get("content") + if content.Type == gjson.String { + systemParts = append(systemParts, content.String()) + } else if content.IsArray() { + // Handle array content format + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + systemParts = append(systemParts, part.Get("text").String()) + } + } + } + } + } + + return strings.Join(systemParts, "\n") +} + +// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format +func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper { + var kiroTools []KiroToolWrapper + if !tools.IsArray() { + return kiroTools + } + + for _, tool := range tools.Array() { + // OpenAI tools have type "function" with function definition inside + if tool.Get("type").String() != "function" { + continue + } + + fn := tool.Get("function") + if !fn.Exists() { + continue + } + + name := 
fn.Get("name").String() + description := fn.Get("description").String() + parameters := fn.Get("parameters").Value() + + // CRITICAL FIX: Kiro API requires non-empty description + if strings.TrimSpace(description) == "" { + description = fmt.Sprintf("Tool: %s", name) + log.Debugf("kiro-openai: tool '%s' has empty description, using default: %s", name, description) + } + + // Truncate long descriptions + if len(description) > kirocommon.KiroMaxToolDescLen { + truncLen := kirocommon.KiroMaxToolDescLen - 30 + for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { + truncLen-- + } + description = description[:truncLen] + "... (description truncated)" + } + + kiroTools = append(kiroTools, KiroToolWrapper{ + ToolSpecification: KiroToolSpecification{ + Name: name, + Description: description, + InputSchema: KiroInputSchema{JSON: parameters}, + }, + }) + } + + return kiroTools +} + +// processOpenAIMessages processes OpenAI messages and builds Kiro history +func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) { + var history []KiroHistoryMessage + var currentUserMsg *KiroUserInputMessage + var currentToolResults []KiroToolResult + + if !messages.IsArray() { + return history, currentUserMsg, currentToolResults + } + + // Merge adjacent messages with the same role + messagesArray := kirocommon.MergeAdjacentMessages(messages.Array()) + + // Build tool_call_id to name mapping from assistant messages + toolCallIDToName := make(map[string]string) + for _, msg := range messagesArray { + if msg.Get("role").String() == "assistant" { + toolCalls := msg.Get("tool_calls") + if toolCalls.IsArray() { + for _, tc := range toolCalls.Array() { + if tc.Get("type").String() == "function" { + id := tc.Get("id").String() + name := tc.Get("function.name").String() + if id != "" && name != "" { + toolCallIDToName[id] = name + } + } + } + } + } + } + + for i, msg := range messagesArray { + role := 
msg.Get("role").String() + isLastMessage := i == len(messagesArray)-1 + + switch role { + case "system": + // System messages are handled separately via extractSystemPromptFromOpenAI + continue + + case "user": + userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin) + if isLastMessage { + currentUserMsg = &userMsg + currentToolResults = toolResults + } else { + // CRITICAL: Kiro API requires content to be non-empty for history messages + if strings.TrimSpace(userMsg.Content) == "" { + if len(toolResults) > 0 { + userMsg.Content = "Tool results provided." + } else { + userMsg.Content = "Continue" + } + } + // For history messages, embed tool results in context + if len(toolResults) > 0 { + userMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + ToolResults: toolResults, + } + } + history = append(history, KiroHistoryMessage{ + UserInputMessage: &userMsg, + }) + } + + case "assistant": + assistantMsg := buildAssistantMessageFromOpenAI(msg) + if isLastMessage { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + // Create a "Continue" user message as currentMessage + currentUserMsg = &KiroUserInputMessage{ + Content: "Continue", + ModelID: modelID, + Origin: origin, + } + } else { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } + + case "tool": + // Tool messages in OpenAI format provide results for tool_calls + // These are typically followed by user or assistant messages + // Process them and merge into the next user message's tool results + toolCallID := msg.Get("tool_call_id").String() + content := msg.Get("content").String() + + if toolCallID != "" { + toolResult := KiroToolResult{ + ToolUseID: toolCallID, + Content: []KiroTextContent{{Text: content}}, + Status: "success", + } + // Tool results should be included in the next user message + // For now, collect them and they'll be handled when we build the current message + 
currentToolResults = append(currentToolResults, toolResult) + } + } + } + + return history, currentUserMsg, currentToolResults +} + +// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results +func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolResults []KiroToolResult + var images []KiroImage + + // Track seen toolCallIds to deduplicate + seenToolCallIDs := make(map[string]bool) + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "image_url": + imageURL := part.Get("image_url.url").String() + if strings.HasPrefix(imageURL, "data:") { + // Parse data URL: data:image/png;base64,xxxxx + if idx := strings.Index(imageURL, ";base64,"); idx != -1 { + mediaType := imageURL[5:idx] // Skip "data:" + data := imageURL[idx+8:] // Skip ";base64," + + format := "" + if lastSlash := strings.LastIndex(mediaType, "/"); lastSlash != -1 { + format = mediaType[lastSlash+1:] + } + + if format != "" && data != "" { + images = append(images, KiroImage{ + Format: format, + Source: KiroImageSource{ + Bytes: data, + }, + }) + } + } + } + } + } + } else if content.Type == gjson.String { + contentBuilder.WriteString(content.String()) + } + + // Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases) + _ = seenToolCallIDs // Used for deduplication if needed + + userMsg := KiroUserInputMessage{ + Content: contentBuilder.String(), + ModelID: modelID, + Origin: origin, + } + + if len(images) > 0 { + userMsg.Images = images + } + + return userMsg, toolResults +} + +// buildAssistantMessageFromOpenAI builds an assistant message from OpenAI format +func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMessage { + 
content := msg.Get("content") + var contentBuilder strings.Builder + var toolUses []KiroToolUse + + // Handle content + if content.Type == gjson.String { + contentBuilder.WriteString(content.String()) + } else if content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + contentBuilder.WriteString(part.Get("text").String()) + } + } + } + + // Handle tool_calls + toolCalls := msg.Get("tool_calls") + if toolCalls.IsArray() { + for _, tc := range toolCalls.Array() { + if tc.Get("type").String() != "function" { + continue + } + + toolUseID := tc.Get("id").String() + toolName := tc.Get("function.name").String() + toolArgs := tc.Get("function.arguments").String() + + var inputMap map[string]interface{} + if err := json.Unmarshal([]byte(toolArgs), &inputMap); err != nil { + log.Debugf("kiro-openai: failed to parse tool arguments: %v", err) + inputMap = make(map[string]interface{}) + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + + return KiroAssistantResponseMessage{ + Content: contentBuilder.String(), + ToolUses: toolUses, + } +} + +// buildFinalContent builds the final content with system prompt +func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string { + var contentBuilder strings.Builder + + if systemPrompt != "" { + contentBuilder.WriteString("--- SYSTEM PROMPT ---\n") + contentBuilder.WriteString(systemPrompt) + contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n") + } + + contentBuilder.WriteString(content) + finalContent := contentBuilder.String() + + // CRITICAL: Kiro API requires content to be non-empty + if strings.TrimSpace(finalContent) == "" { + if len(toolResults) > 0 { + finalContent = "Tool results provided." 
+ } else { + finalContent = "Continue" + } + log.Debugf("kiro-openai: content was empty, using default: %s", finalContent) + } + + return finalContent +} + +// deduplicateToolResults removes duplicate tool results +func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult { + if len(toolResults) == 0 { + return toolResults + } + + seenIDs := make(map[string]bool) + unique := make([]KiroToolResult, 0, len(toolResults)) + for _, tr := range toolResults { + if !seenIDs[tr.ToolUseID] { + seenIDs[tr.ToolUseID] = true + unique = append(unique, tr) + } else { + log.Debugf("kiro-openai: skipping duplicate toolResult: %s", tr.ToolUseID) + } + } + return unique +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai_response.go b/internal/translator/kiro/openai/kiro_openai_response.go new file mode 100644 index 00000000..b7da1373 --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai_response.go @@ -0,0 +1,264 @@ +// Package openai provides response translation from Kiro to OpenAI format. +// This package handles the conversion of Kiro API responses into OpenAI Chat Completions-compatible +// JSON format, transforming streaming events and non-streaming responses. +package openai + +import ( + "encoding/json" + "fmt" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + log "github.com/sirupsen/logrus" +) + +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + +// BuildOpenAIResponse constructs an OpenAI Chat Completions-compatible response. +// Supports tool_calls when tools are present in the response. +// stopReason is passed from upstream; fallback logic applied if empty. 
+func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { + // Build the message object + message := map[string]interface{}{ + "role": "assistant", + "content": content, + } + + // Add tool_calls if present + if len(toolUses) > 0 { + var toolCalls []map[string]interface{} + for i, tu := range toolUses { + inputJSON, _ := json.Marshal(tu.Input) + toolCalls = append(toolCalls, map[string]interface{}{ + "id": tu.ToolUseID, + "type": "function", + "index": i, + "function": map[string]interface{}{ + "name": tu.Name, + "arguments": string(inputJSON), + }, + }) + } + message["tool_calls"] = toolCalls + // When tool_calls are present, content should be null according to OpenAI spec + if content == "" { + message["content"] = nil + } + } + + // Use upstream stopReason; apply fallback logic if not provided + finishReason := mapKiroStopReasonToOpenAI(stopReason) + if finishReason == "" { + finishReason = "stop" + if len(toolUses) > 0 { + finishReason = "tool_calls" + } + log.Debugf("kiro-openai: buildOpenAIResponse using fallback finish_reason: %s", finishReason) + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "message": message, + "finish_reason": finishReason, + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": usageInfo.InputTokens, + "completion_tokens": usageInfo.OutputTokens, + "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens, + }, + } + + result, _ := json.Marshal(response) + return result +} + +// mapKiroStopReasonToOpenAI converts Kiro/Claude stop_reason to OpenAI finish_reason +func mapKiroStopReasonToOpenAI(stopReason string) string { + switch stopReason { + case "end_turn": + return "stop" + case "stop_sequence": + return "stop" + case "tool_use": + return "tool_calls" + case 
"max_tokens": + return "length" + case "content_filtered": + return "content_filter" + default: + return stopReason + } +} + +// BuildOpenAIStreamChunk constructs an OpenAI Chat Completions streaming chunk. +// This is the delta format used in streaming responses. +func BuildOpenAIStreamChunk(model string, deltaContent string, deltaToolCalls []map[string]interface{}, finishReason string, index int) []byte { + delta := map[string]interface{}{} + + // First chunk should include role + if index == 0 && deltaContent == "" && len(deltaToolCalls) == 0 { + delta["role"] = "assistant" + delta["content"] = "" + } else if deltaContent != "" { + delta["content"] = deltaContent + } + + // Add tool_calls delta if present + if len(deltaToolCalls) > 0 { + delta["tool_calls"] = deltaToolCalls + } + + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + } + + if finishReason != "" { + choice["finish_reason"] = finishReason + } else { + choice["finish_reason"] = nil + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamChunkWithToolCallStart creates a stream chunk for tool call start +func BuildOpenAIStreamChunkWithToolCallStart(model string, toolUseID, toolName string, toolIndex int) []byte { + toolCall := map[string]interface{}{ + "index": toolIndex, + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": "", + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + "finish_reason": nil, + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + 
"model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamChunkWithToolCallDelta creates a stream chunk for tool call arguments delta +func BuildOpenAIStreamChunkWithToolCallDelta(model string, argumentsDelta string, toolIndex int) []byte { + toolCall := map[string]interface{}{ + "index": toolIndex, + "function": map[string]interface{}{ + "arguments": argumentsDelta, + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + "finish_reason": nil, + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamDoneChunk creates the final [DONE] stream event +func BuildOpenAIStreamDoneChunk() []byte { + return []byte("data: [DONE]") +} + +// BuildOpenAIStreamFinishChunk creates the final chunk with finish_reason +func BuildOpenAIStreamFinishChunk(model string, finishReason string) []byte { + choice := map[string]interface{}{ + "index": 0, + "delta": map[string]interface{}{}, + "finish_reason": finishReason, + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamUsageChunk creates a chunk with usage information (optional, for stream_options.include_usage) +func BuildOpenAIStreamUsageChunk(model string, usageInfo usage.Detail) []byte { + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + 
"model": model, + "choices": []map[string]interface{}{}, + "usage": map[string]interface{}{ + "prompt_tokens": usageInfo.InputTokens, + "completion_tokens": usageInfo.OutputTokens, + "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens, + }, + } + + result, _ := json.Marshal(chunk) + return result +} + +// GenerateToolCallID generates a unique tool call ID in OpenAI format +func GenerateToolCallID(toolName string) string { + return fmt.Sprintf("call_%s_%d_%d", toolName[:min(8, len(toolName))], time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)) +} + +// min returns the minimum of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai_stream.go b/internal/translator/kiro/openai/kiro_openai_stream.go new file mode 100644 index 00000000..d550a8d8 --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai_stream.go @@ -0,0 +1,207 @@ +// Package openai provides streaming SSE event building for OpenAI format. +// This package handles the construction of OpenAI-compatible Server-Sent Events (SSE) +// for streaming responses from Kiro API. 
+package openai + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" +) + +// OpenAIStreamState tracks the state of streaming response conversion +type OpenAIStreamState struct { + ChunkIndex int + ToolCallIndex int + HasSentFirstChunk bool + Model string + ResponseID string + Created int64 +} + +// NewOpenAIStreamState creates a new stream state for tracking +func NewOpenAIStreamState(model string) *OpenAIStreamState { + return &OpenAIStreamState{ + ChunkIndex: 0, + ToolCallIndex: 0, + HasSentFirstChunk: false, + Model: model, + ResponseID: "chatcmpl-" + uuid.New().String()[:24], + Created: time.Now().Unix(), + } +} + +// FormatSSEEvent formats a JSON payload as an SSE event +func FormatSSEEvent(data []byte) string { + return fmt.Sprintf("data: %s", string(data)) +} + +// BuildOpenAISSETextDelta creates an SSE event for text content delta +func BuildOpenAISSETextDelta(state *OpenAIStreamState, textDelta string) string { + delta := map[string]interface{}{ + "content": textDelta, + } + + // Include role in first chunk + if !state.HasSentFirstChunk { + delta["role"] = "assistant" + state.HasSentFirstChunk = true + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEToolCallStart creates an SSE event for tool call start +func BuildOpenAISSEToolCallStart(state *OpenAIStreamState, toolUseID, toolName string) string { + toolCall := map[string]interface{}{ + "index": state.ToolCallIndex, + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": "", + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + // Include role in first chunk if not sent yet + if !state.HasSentFirstChunk { + delta["role"] = "assistant" + state.HasSentFirstChunk = true + } + + chunk := 
buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEToolCallArgumentsDelta creates an SSE event for tool call arguments delta +func BuildOpenAISSEToolCallArgumentsDelta(state *OpenAIStreamState, argumentsDelta string, toolIndex int) string { + toolCall := map[string]interface{}{ + "index": toolIndex, + "function": map[string]interface{}{ + "arguments": argumentsDelta, + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEFinish creates an SSE event with finish_reason +func BuildOpenAISSEFinish(state *OpenAIStreamState, finishReason string) string { + chunk := buildBaseChunk(state, map[string]interface{}{}, &finishReason) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEUsage creates an SSE event with usage information +func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) string { + chunk := map[string]interface{}{ + "id": state.ResponseID, + "object": "chat.completion.chunk", + "created": state.Created, + "model": state.Model, + "choices": []map[string]interface{}{}, + "usage": map[string]interface{}{ + "prompt_tokens": usageInfo.InputTokens, + "completion_tokens": usageInfo.OutputTokens, + "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens, + }, + } + result, _ := json.Marshal(chunk) + return FormatSSEEvent(result) +} + +// BuildOpenAISSEDone creates the final [DONE] SSE event +func BuildOpenAISSEDone() string { + return "data: [DONE]" +} + +// buildBaseChunk creates a base chunk structure for streaming +func buildBaseChunk(state *OpenAIStreamState, delta map[string]interface{}, finishReason *string) map[string]interface{} { + choice := map[string]interface{}{ + "index": 0, + 
"delta": delta, + } + + if finishReason != nil { + choice["finish_reason"] = *finishReason + } else { + choice["finish_reason"] = nil + } + + return map[string]interface{}{ + "id": state.ResponseID, + "object": "chat.completion.chunk", + "created": state.Created, + "model": state.Model, + "choices": []map[string]interface{}{choice}, + } +} + +// BuildOpenAISSEReasoningDelta creates an SSE event for reasoning content delta +// This is used for o1/o3 style models that expose reasoning tokens +func BuildOpenAISSEReasoningDelta(state *OpenAIStreamState, reasoningDelta string) string { + delta := map[string]interface{}{ + "reasoning_content": reasoningDelta, + } + + // Include role in first chunk + if !state.HasSentFirstChunk { + delta["role"] = "assistant" + state.HasSentFirstChunk = true + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEFirstChunk creates the first chunk with role only +func BuildOpenAISSEFirstChunk(state *OpenAIStreamState) string { + delta := map[string]interface{}{ + "role": "assistant", + "content": "", + } + + state.HasSentFirstChunk = true + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// ThinkingTagState tracks state for thinking tag detection in streaming +type ThinkingTagState struct { + InThinkingBlock bool + PendingStartChars int + PendingEndChars int +} + +// NewThinkingTagState creates a new thinking tag state +func NewThinkingTagState() *ThinkingTagState { + return &ThinkingTagState{ + InThinkingBlock: false, + PendingStartChars: 0, + PendingEndChars: 0, + } +} \ No newline at end of file From 9c04c18c0484d3162b62ec9dcc6edca415fbe988 Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sun, 14 Dec 2025 11:54:57 +0800 Subject: [PATCH 36/38] feat(kiro): enhance request translation and fix streaming issues MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit English: - Fix tag parsing: only parse at response start, avoid misinterpreting discussion text - Add token counting support using tiktoken for local estimation - Support top_p parameter in inference config - Handle max_tokens=-1 as maximum (32000 tokens) - Add tool_choice and response_format parameter handling via system prompt hints - Support multiple thinking mode detection formats (Claude API, OpenAI reasoning_effort, AMP/Cursor) - Shorten MCP tool names exceeding 64 characters - Fix duplicate [DONE] marker in OpenAI SSE streaming - Enhance token usage statistics with multiple event format support - Add code fence markers to constants 中文: - 修复 标签解析:仅在响应开头解析,避免误解析讨论文本中的标签 - 使用 tiktoken 实现本地 token 计数功能 - 支持 top_p 推理配置参数 - 处理 max_tokens=-1 转换为最大值(32000 tokens) - 通过系统提示词注入实现 tool_choice 和 response_format 参数支持 - 支持多种思考模式检测格式(Claude API、OpenAI reasoning_effort、AMP/Cursor) - 截断超过64字符的 MCP 工具名称 - 修复 OpenAI SSE 流中重复的 [DONE] 标记 - 增强 token 使用量统计,支持多种事件格式 - 添加代码围栏标记常量 --- internal/runtime/executor/kiro_executor.go | 855 +++++++++++++++++- .../kiro/claude/kiro_claude_request.go | 176 +++- internal/translator/kiro/common/constants.go | 9 + .../translator/kiro/openai/kiro_openai.go | 5 +- .../kiro/openai/kiro_openai_request.go | 245 ++++- .../kiro/openai/kiro_openai_stream.go | 15 +- 6 files changed, 1278 insertions(+), 27 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 1d4d85a5..b9a38272 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -19,6 +19,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + kiroopenai "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai" 
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -161,6 +162,22 @@ type KiroExecutor struct { refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions } +// buildKiroPayloadForFormat builds the Kiro API payload based on the source format. +// This is critical because OpenAI and Claude formats have different tool structures: +// - OpenAI: tools[].function.name, tools[].function.description +// - Claude: tools[].name, tools[].description +func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format) []byte { + switch sourceFormat.String() { + case "openai": + log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) + return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly) + default: + // Default to Claude format (also handles "claude", "kiro", etc.) + log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String()) + return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly) + } +} + // NewKiroExecutor creates a new Kiro executor instance. func NewKiroExecutor(cfg *config.Config) *KiroExecutor { return &KiroExecutor{cfg: cfg} @@ -231,7 +248,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
// Rebuild payload with the correct origin for this endpoint // Each endpoint requires its matching Origin value in the request body - kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)", endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) @@ -341,7 +358,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed - kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed successfully, retrying request") continue } @@ -398,7 +415,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
if refreshedAuth != nil { auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) - kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed for 403, retrying request") continue } @@ -537,7 +554,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox // Rebuild payload with the correct origin for this endpoint // Each endpoint requires its matching Origin value in the request body - kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) @@ -660,7 +677,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed - kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed successfully, retrying stream request") continue } @@ -717,7 +734,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox if refreshedAuth != nil { auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) - kiroPayload = kiroclaude.BuildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token 
refreshed for 403, retrying stream request") continue } @@ -755,7 +772,20 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } }() - e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter) + // Check if thinking mode was enabled in the original request + // Only parse tags when thinking was explicitly requested + // Check multiple sources: original request, pre-translation payload, and translated body + // This handles different client formats (Claude API, OpenAI, AMP/Cursor) + thinkingEnabled := kiroclaude.IsThinkingEnabled(opts.OriginalRequest) + if !thinkingEnabled { + thinkingEnabled = kiroclaude.IsThinkingEnabled(req.Payload) + } + if !thinkingEnabled { + thinkingEnabled = kiroclaude.IsThinkingEnabled(body) + } + log.Debugf("kiro: stream thinkingEnabled = %v", thinkingEnabled) + + e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled) }(httpResp) return out, nil @@ -807,6 +837,153 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { return accessToken, profileArn } +// findRealThinkingEndTag finds the real end tag, skipping false positives. +// Returns -1 if no real end tag is found. 
+// +// Real tags from Kiro API have specific characteristics: +// - Usually preceded by newline (.\n) +// - Usually followed by newline (\n\n) +// - Not inside code blocks or inline code +// +// False positives (discussion text) have characteristics: +// - In the middle of a sentence +// - Preceded by discussion words like "标签", "tag", "returns" +// - Inside code blocks or inline code +// +// Parameters: +// - content: the content to search in +// - alreadyInCodeBlock: whether we're already inside a code block from previous chunks +// - alreadyInInlineCode: whether we're already inside inline code from previous chunks +func findRealThinkingEndTag(content string, alreadyInCodeBlock, alreadyInInlineCode bool) int { + searchStart := 0 + for { + endIdx := strings.Index(content[searchStart:], kirocommon.ThinkingEndTag) + if endIdx < 0 { + return -1 + } + endIdx += searchStart // Adjust to absolute position + + textBeforeEnd := content[:endIdx] + textAfterEnd := content[endIdx+len(kirocommon.ThinkingEndTag):] + + // Check 1: Is it inside inline code? + // Count backticks in current content and add state from previous chunks + backtickCount := strings.Count(textBeforeEnd, "`") + effectiveInInlineCode := alreadyInInlineCode + if backtickCount%2 == 1 { + effectiveInInlineCode = !effectiveInInlineCode + } + if effectiveInInlineCode { + log.Debugf("kiro: found inside inline code at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + + // Check 2: Is it inside a code block? 
+ // Count fences in current content and add state from previous chunks + fenceCount := strings.Count(textBeforeEnd, "```") + altFenceCount := strings.Count(textBeforeEnd, "~~~") + effectiveInCodeBlock := alreadyInCodeBlock + if fenceCount%2 == 1 || altFenceCount%2 == 1 { + effectiveInCodeBlock = !effectiveInCodeBlock + } + if effectiveInCodeBlock { + log.Debugf("kiro: found inside code block at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + + // Check 3: Real tags are usually preceded by newline or at start + // and followed by newline or at end. Check the format. + charBeforeTag := byte(0) + if endIdx > 0 { + charBeforeTag = content[endIdx-1] + } + charAfterTag := byte(0) + if len(textAfterEnd) > 0 { + charAfterTag = textAfterEnd[0] + } + + // Real end tag format: preceded by newline OR end of sentence (. ! ?) + // and followed by newline OR end of content + isPrecededByNewlineOrSentenceEnd := charBeforeTag == '\n' || charBeforeTag == '.' || + charBeforeTag == '!' || charBeforeTag == '?' || charBeforeTag == 0 + isFollowedByNewlineOrEnd := charAfterTag == '\n' || charAfterTag == 0 + + // If the tag has proper formatting (newline before/after), it's likely real + if isPrecededByNewlineOrSentenceEnd && isFollowedByNewlineOrEnd { + log.Debugf("kiro: found properly formatted at pos %d", endIdx) + return endIdx + } + + // Check 4: Is the tag preceded by discussion keywords on the same line? 
+ lastNewlineIdx := strings.LastIndex(textBeforeEnd, "\n") + lineBeforeTag := textBeforeEnd + if lastNewlineIdx >= 0 { + lineBeforeTag = textBeforeEnd[lastNewlineIdx+1:] + } + lineBeforeTagLower := strings.ToLower(lineBeforeTag) + + // Discussion patterns - if found, this is likely discussion text + discussionPatterns := []string{ + "标签", "返回", "输出", "包含", "使用", "解析", "转换", "生成", // Chinese + "tag", "return", "output", "contain", "use", "parse", "emit", "convert", "generate", // English + "", // discussing both tags together + "``", // explicitly in inline code + } + isDiscussion := false + for _, pattern := range discussionPatterns { + if strings.Contains(lineBeforeTagLower, pattern) { + isDiscussion = true + break + } + } + if isDiscussion { + log.Debugf("kiro: found after discussion text at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + + // Check 5: Is there text immediately after on the same line? + // Real end tags don't have text immediately after on the same line + if len(textAfterEnd) > 0 && charAfterTag != '\n' && charAfterTag != 0 { + // Find the next newline + nextNewline := strings.Index(textAfterEnd, "\n") + var textOnSameLine string + if nextNewline >= 0 { + textOnSameLine = textAfterEnd[:nextNewline] + } else { + textOnSameLine = textAfterEnd + } + // If there's non-whitespace text on the same line after the tag, it's discussion + if strings.TrimSpace(textOnSameLine) != "" { + log.Debugf("kiro: found with text after on same line at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + } + + // Check 6: Is there another tag after this ? 
+ if strings.Contains(textAfterEnd, kirocommon.ThinkingStartTag) { + nextStartIdx := strings.Index(textAfterEnd, kirocommon.ThinkingStartTag) + textBeforeNextStart := textAfterEnd[:nextStartIdx] + nextBacktickCount := strings.Count(textBeforeNextStart, "`") + nextFenceCount := strings.Count(textBeforeNextStart, "```") + nextAltFenceCount := strings.Count(textBeforeNextStart, "~~~") + + // If the next is NOT in code, then this is discussion text + if nextBacktickCount%2 == 0 && nextFenceCount%2 == 0 && nextAltFenceCount%2 == 0 { + log.Debugf("kiro: found followed by at pos %d, likely discussion text, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + } + + // This looks like a real end tag + return endIdx + } +} + // determineAgenticMode determines if the model is an agentic or chat-only variant. // Returns (isAgentic, isChatOnly) based on model name suffixes. func determineAgenticMode(model string) (isAgentic, isChatOnly bool) { @@ -963,6 +1140,9 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.Ki processedIDs := make(map[string]bool) var currentToolUse *kiroclaude.ToolUseState + // Upstream usage tracking - Kiro API returns credit usage and context percentage + var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56) + for { msg, eventErr := e.readEventStreamMessage(reader) if eventErr != nil { @@ -1119,6 +1299,119 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.Ki stopReason = sr log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason) } + + case "messageMetadataEvent": + // Handle message metadata events which may contain token counts + if metadata, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metadata["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + log.Debugf("kiro: parseEventStream found inputTokens in 
messageMetadataEvent: %d", usageInfo.InputTokens) + } + if outputTokens, ok := metadata["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + log.Debugf("kiro: parseEventStream found outputTokens in messageMetadataEvent: %d", usageInfo.OutputTokens) + } + if totalTokens, ok := metadata["totalTokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + log.Debugf("kiro: parseEventStream found totalTokens in messageMetadataEvent: %d", usageInfo.TotalTokens) + } + } + + case "usageEvent", "usage": + // Handle dedicated usage events + if inputTokens, ok := event["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + log.Debugf("kiro: parseEventStream found inputTokens in usageEvent: %d", usageInfo.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + log.Debugf("kiro: parseEventStream found outputTokens in usageEvent: %d", usageInfo.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + log.Debugf("kiro: parseEventStream found totalTokens in usageEvent: %d", usageInfo.TotalTokens) + } + // Also check nested usage object + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: parseEventStream found usage object: input=%d, output=%d, total=%d", + 
usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens) + } + + case "metricsEvent": + // Handle metrics events which may contain usage data + if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metrics["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := metrics["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + log.Debugf("kiro: parseEventStream found metricsEvent: input=%d, output=%d", + usageInfo.InputTokens, usageInfo.OutputTokens) + } + + default: + // Check for contextUsagePercentage in any event + if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: parseEventStream received context usage: %.2f%%", upstreamContextPercentage) + } + // Log unknown event types for debugging (to discover new event formats) + log.Debugf("kiro: parseEventStream unknown event type: %s, payload: %s", eventType, string(payload)) + } + + // Check for direct token fields in any event (fallback) + if usageInfo.InputTokens == 0 { + if inputTokens, ok := event["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + log.Debugf("kiro: parseEventStream found direct inputTokens: %d", usageInfo.InputTokens) + } + } + if usageInfo.OutputTokens == 0 { + if outputTokens, ok := event["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + log.Debugf("kiro: parseEventStream found direct outputTokens: %d", usageInfo.OutputTokens) + } + } + + // Check for usage object in any event (OpenAI format) + if usageInfo.InputTokens == 0 || usageInfo.OutputTokens == 0 { + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if usageInfo.InputTokens == 0 { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + 
usageInfo.InputTokens = int64(inputTokens) + } + } + if usageInfo.OutputTokens == 0 { + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + } + if usageInfo.TotalTokens == 0 { + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + } + } + log.Debugf("kiro: parseEventStream found usage object (fallback): input=%d, output=%d, total=%d", + usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens) + } } // Also check nested supplementaryWebLinksEvent @@ -1157,6 +1450,20 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.Ki log.Warnf("kiro: response truncated due to max_tokens limit") } + // Use contextUsagePercentage to calculate more accurate input tokens + // Kiro model has 200k max context, contextUsagePercentage represents the percentage used + // Formula: input_tokens = contextUsagePercentage * 200000 / 100 + if upstreamContextPercentage > 0 { + calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) + if calculatedInputTokens > 0 { + localEstimate := usageInfo.InputTokens + usageInfo.InputTokens = calculatedInputTokens + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens + log.Infof("kiro: parseEventStream using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", + upstreamContextPercentage, calculatedInputTokens, localEstimate) + } + } + return cleanedContent, toolUses, usageInfo, stopReason, nil } @@ -1357,7 +1664,8 @@ func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { // Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. // Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). 
// Extracts stop_reason from upstream events when available. -func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { +// thinkingEnabled controls whether tags are parsed - only parse when request enabled thinking. +func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) { reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers var totalUsage usage.Detail var hasToolUses bool // Track if any tool uses were emitted @@ -1383,6 +1691,11 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out var lastUsageUpdateTime = time.Now() // Last time usage update was sent var lastReportedOutputTokens int64 // Last reported output token count + // Upstream usage tracking - Kiro API returns credit usage and context percentage + var upstreamCreditUsage float64 // Credit usage from upstream (e.g., 1.458) + var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56) + var hasUpstreamUsage bool // Whether we received usage from upstream + // Translator param for maintaining tool call state across streaming events // IMPORTANT: This must persist across all TranslateStream calls var translatorParam any @@ -1395,6 +1708,22 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out isThinkingBlockOpen := false // Track if thinking content block is open thinkingBlockIndex := -1 // Index of the thinking content block + // Code block state tracking for heuristic thinking tag parsing + // When inside a markdown code block, tags should NOT be parsed + // This prevents false positives when the model outputs code examples 
containing these tags + inCodeBlock := false + codeFenceType := "" // Track which fence type opened the block ("```" or "~~~") + + // Inline code state tracking - when inside backticks, don't parse thinking tags + // This handles cases like `` being discussed in text + inInlineCode := false + + // Track if we've seen any non-whitespace content before a thinking tag + // Real thinking blocks from Kiro always start at the very beginning of the response + // If we see content before , subsequent tags are likely discussion text + hasSeenNonThinkingContent := false + thinkingBlockCompleted := false // Track if we've already completed a thinking block + // Pre-calculate input tokens from request if possible // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback if enc, err := getTokenizer(model); err == nil { @@ -1629,6 +1958,66 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) } + default: + // Check for upstream usage events from Kiro API + // Format: {"unit":"credit","unitPlural":"credits","usage":1.458} + if unit, ok := event["unit"].(string); ok && unit == "credit" { + if usage, ok := event["usage"].(float64); ok { + upstreamCreditUsage = usage + hasUpstreamUsage = true + log.Debugf("kiro: received upstream credit usage: %.4f", upstreamCreditUsage) + } + } + // Format: {"contextUsagePercentage":78.56} + if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: received upstream context usage: %.2f%%", upstreamContextPercentage) + } + + // Check for token counts in unknown events + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found inputTokens in event %s: %d", eventType, totalUsage.InputTokens) + } + if outputTokens, ok 
:= event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found outputTokens in event %s: %d", eventType, totalUsage.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in event %s: %d", eventType, totalUsage.TotalTokens) + } + + // Check for usage object in unknown events (OpenAI/Claude format) + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: streamToChannel found usage object in event %s: input=%d, output=%d, total=%d", + eventType, totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + // Log unknown event types for debugging (to discover new event formats) + if eventType != "" { + log.Debugf("kiro: streamToChannel unknown event type: %s, payload: %s", eventType, string(payload)) + } + case "assistantResponseEvent": var contentDelta string var toolUses []map[string]interface{} @@ -1742,9 +2131,243 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } for len(remaining) > 0 { + // CRITICAL FIX: Only parse tags when thinking mode was enabled in the request. 
+ // When thinking is NOT enabled, tags in responses should be treated as + // regular text content, not as thinking blocks. This prevents normal text content + // from being incorrectly parsed as thinking when the model outputs tags + // without the user requesting thinking mode. + if !thinkingEnabled { + // Thinking not enabled - emit all content as regular text without parsing tags + if remaining != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + break // Exit the for loop - all content processed as text + } + + // HEURISTIC FIX: Track code block and inline code state to avoid parsing tags + // inside code contexts. When the model outputs code examples containing these tags, + // they should be treated as text. 
+ if !inThinkBlock { + // Check for inline code backticks first (higher priority than code fences) + // This handles cases like `` being discussed in text + backtickIdx := strings.Index(remaining, kirocommon.InlineCodeMarker) + thinkingIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + + // If backtick comes before thinking tag, handle inline code + if backtickIdx >= 0 && (thinkingIdx < 0 || backtickIdx < thinkingIdx) { + if inInlineCode { + // Closing backtick - emit content up to and including backtick, exit inline code + textToEmit := remaining[:backtickIdx+1] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[backtickIdx+1:] + inInlineCode = false + continue + } else { + // Opening backtick - emit content before backtick, enter inline code + textToEmit := remaining[:backtickIdx+1] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for 
_, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[backtickIdx+1:] + inInlineCode = true + continue + } + } + + // If inside inline code, emit all content as text (don't parse thinking tags) + if inInlineCode { + if remaining != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + break // Exit loop - remaining content is inside inline code + } + + // Check for code fence markers (``` or ~~~) to toggle code block state + fenceIdx := strings.Index(remaining, kirocommon.CodeFenceMarker) + altFenceIdx := strings.Index(remaining, kirocommon.AltCodeFenceMarker) + + // Find the earliest fence marker + earliestFenceIdx := -1 + earliestFenceType := "" + if fenceIdx >= 0 && (altFenceIdx < 0 || fenceIdx < altFenceIdx) { + 
earliestFenceIdx = fenceIdx + earliestFenceType = kirocommon.CodeFenceMarker + } else if altFenceIdx >= 0 { + earliestFenceIdx = altFenceIdx + earliestFenceType = kirocommon.AltCodeFenceMarker + } + + if earliestFenceIdx >= 0 { + // Check if this fence comes before any thinking tag + thinkingIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + if inCodeBlock { + // Inside code block - check if this fence closes it + if earliestFenceType == codeFenceType { + // This fence closes the code block + // Emit content up to and including the fence as text + textToEmit := remaining[:earliestFenceIdx+len(earliestFenceType)] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[earliestFenceIdx+len(earliestFenceType):] + inCodeBlock = false + codeFenceType = "" + log.Debugf("kiro: exited code block") + continue + } + } else if thinkingIdx < 0 || earliestFenceIdx < thinkingIdx { + // Not in code block, and fence comes before thinking tag (or no thinking tag) + // Emit content up to and including the fence as text, then enter code block + textToEmit := remaining[:earliestFenceIdx+len(earliestFenceType)] + if textToEmit != "" { + if !isTextBlockOpen { + 
contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[earliestFenceIdx+len(earliestFenceType):] + inCodeBlock = true + codeFenceType = earliestFenceType + log.Debugf("kiro: entered code block with fence: %s", earliestFenceType) + continue + } + } + + // If inside code block, emit all content as text (don't parse thinking tags) + if inCodeBlock { + if remaining != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + break 
// Exit loop - all remaining content is inside code block + } + } + if inThinkBlock { // Inside thinking block - look for end tag - endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag) + // CRITICAL FIX: Skip tags that are not the real end tag + // This prevents false positives when thinking content discusses these tags + // Pass current code block/inline code state for accurate detection + endIdx := findRealThinkingEndTag(remaining, inCodeBlock, inInlineCode) + if endIdx >= 0 { // Found end tag - emit any content before end tag, then close block thinkContent := remaining[:endIdx] @@ -1790,8 +2413,9 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } inThinkBlock = false + thinkingBlockCompleted = true // Mark that we've completed a thinking block remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):] - log.Debugf("kiro: exited thinking block") + log.Debugf("kiro: exited thinking block, subsequent tags will be treated as text") } else { // No end tag found - TRUE STREAMING: emit content immediately // Only save potential partial tag length for next iteration @@ -1837,11 +2461,29 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } else { // Outside thinking block - look for start tag - startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + // CRITICAL FIX: Only parse tags at the very beginning of the response + // or if we haven't completed a thinking block yet. + // After a thinking block is completed, subsequent tags are likely + // discussion text (e.g., "Kiro returns `` tags") and should NOT be parsed. 
+ startIdx := -1 + if !thinkingBlockCompleted && !hasSeenNonThinkingContent { + startIdx = strings.Index(remaining, kirocommon.ThinkingStartTag) + // If there's non-whitespace content before the tag, it's not a real thinking block + if startIdx > 0 { + textBefore := remaining[:startIdx] + if strings.TrimSpace(textBefore) != "" { + // There's real content before the tag - this is discussion text, not thinking + hasSeenNonThinkingContent = true + startIdx = -1 + log.Debugf("kiro: found tag after non-whitespace content, treating as text") + } + } + } if startIdx >= 0 { // Found start tag - emit text before it and switch to thinking mode textBefore := remaining[:startIdx] if textBefore != "" { + // Only whitespace before thinking tag is allowed // Start text content block if needed if !isTextBlockOpen { contentBlockIndex++ @@ -1881,11 +2523,19 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out log.Debugf("kiro: entered thinking block") } else { // No start tag found - check for partial start tag at buffer end - pendingStart := kiroclaude.PendingTagSuffix(remaining, kirocommon.ThinkingStartTag) + // Only check for partial tags if we haven't completed a thinking block yet + pendingStart := 0 + if !thinkingBlockCompleted && !hasSeenNonThinkingContent { + pendingStart = kiroclaude.PendingTagSuffix(remaining, kirocommon.ThinkingStartTag) + } if pendingStart > 0 { // Emit text except potential partial tag textToEmit := remaining[:len(remaining)-pendingStart] if textToEmit != "" { + // Mark that we've seen non-thinking content + if strings.TrimSpace(textToEmit) != "" { + hasSeenNonThinkingContent = true + } // Start text content block if needed if !isTextBlockOpen { contentBlockIndex++ @@ -1912,6 +2562,10 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } else { // No partial tag - emit all as text if remaining != "" { + // Mark that we've seen non-thinking content + if strings.TrimSpace(remaining) != "" { + 
hasSeenNonThinkingContent = true + } // Start text content block if needed if !isTextBlockOpen { contentBlockIndex++ @@ -2065,6 +2719,69 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if outputTokens, ok := event["outputTokens"].(float64); ok { totalUsage.OutputTokens = int64(outputTokens) } + + case "messageMetadataEvent": + // Handle message metadata events which may contain token counts + if metadata, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metadata["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found inputTokens in messageMetadataEvent: %d", totalUsage.InputTokens) + } + if outputTokens, ok := metadata["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found outputTokens in messageMetadataEvent: %d", totalUsage.OutputTokens) + } + if totalTokens, ok := metadata["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in messageMetadataEvent: %d", totalUsage.TotalTokens) + } + } + + case "usageEvent", "usage": + // Handle dedicated usage events + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found inputTokens in usageEvent: %d", totalUsage.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found outputTokens in usageEvent: %d", totalUsage.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in usageEvent: %d", totalUsage.TotalTokens) + } + // Also check nested usage object + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if 
inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: streamToChannel found usage object: input=%d, output=%d, total=%d", + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + case "metricsEvent": + // Handle metrics events which may contain usage data + if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metrics["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := metrics["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + log.Debugf("kiro: streamToChannel found metricsEvent: input=%d, output=%d", + totalUsage.InputTokens, totalUsage.OutputTokens) + } } // Check nested usage event @@ -2076,6 +2793,47 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out totalUsage.OutputTokens = int64(outputTokens) } } + + // Check for direct token fields in any event (fallback) + if totalUsage.InputTokens == 0 { + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found direct inputTokens: %d", totalUsage.InputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found direct outputTokens: %d", totalUsage.OutputTokens) + 
} + } + + // Check for usage object in any event (OpenAI format) + if totalUsage.InputTokens == 0 || totalUsage.OutputTokens == 0 { + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if totalUsage.InputTokens == 0 { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + if totalUsage.TotalTokens == 0 { + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + } + log.Debugf("kiro: streamToChannel found usage object (fallback): input=%d, output=%d, total=%d", + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + } } // Close content block if open @@ -2120,8 +2878,35 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out totalUsage.OutputTokens = 1 } } + + // Use contextUsagePercentage to calculate more accurate input tokens + // Kiro model has 200k max context, contextUsagePercentage represents the percentage used + // Formula: input_tokens = contextUsagePercentage * 200000 / 100 + // Note: The effective input context is ~170k (200k - 30k reserved for output) + if upstreamContextPercentage > 0 { + // Calculate input tokens from context percentage + // Using 200k as the base since that's what Kiro reports against + calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) + + // Only use calculated value if it's significantly different from local estimate + // This provides more accurate token counts based on upstream data + if calculatedInputTokens > 0 { + 
localEstimate := totalUsage.InputTokens + totalUsage.InputTokens = calculatedInputTokens + log.Infof("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", + upstreamContextPercentage, calculatedInputTokens, localEstimate) + } + } + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + // Log upstream usage information if received + if hasUpstreamUsage { + log.Infof("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d", + upstreamCreditUsage, upstreamContextPercentage, + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn stopReason := upstreamStopReason if stopReason == "" { @@ -2162,10 +2947,48 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // NOTE: Claude SSE event builders moved to internal/translator/kiro/claude/kiro_claude_stream.go // The executor now uses kiroclaude.BuildClaude*Event() functions instead -// CountTokens is not supported for Kiro provider. -// Kiro/Amazon Q backend doesn't expose a token counting API. -func (e *KiroExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for kiro"} +// CountTokens counts tokens locally using tiktoken since Kiro API doesn't expose a token counting endpoint. +// This provides approximate token counts for client requests. 
+func (e *KiroExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Use tiktoken for local token counting + enc, err := getTokenizer(req.Model) + if err != nil { + log.Warnf("kiro: CountTokens failed to get tokenizer: %v, falling back to estimate", err) + // Fallback: estimate from payload size (roughly 4 chars per token) + estimatedTokens := len(req.Payload) / 4 + if estimatedTokens == 0 && len(req.Payload) > 0 { + estimatedTokens = 1 + } + return cliproxyexecutor.Response{ + Payload: []byte(fmt.Sprintf(`{"count":%d}`, estimatedTokens)), + }, nil + } + + // Try to count tokens from the request payload + var totalTokens int64 + + // Try OpenAI chat format first + if tokens, countErr := countOpenAIChatTokens(enc, req.Payload); countErr == nil && tokens > 0 { + totalTokens = tokens + log.Debugf("kiro: CountTokens counted %d tokens using OpenAI chat format", totalTokens) + } else { + // Fallback: count raw payload tokens + if tokenCount, countErr := enc.Count(string(req.Payload)); countErr == nil { + totalTokens = int64(tokenCount) + log.Debugf("kiro: CountTokens counted %d tokens from raw payload", totalTokens) + } else { + // Final fallback: estimate from payload size + totalTokens = int64(len(req.Payload) / 4) + if totalTokens == 0 && len(req.Payload) > 0 { + totalTokens = 1 + } + log.Debugf("kiro: CountTokens estimated %d tokens from payload size", totalTokens) + } + } + + return cliproxyexecutor.Response{ + Payload: []byte(fmt.Sprintf(`{"count":%d}`, totalTokens)), + }, nil } // Refresh refreshes the Kiro OAuth token. 
diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index 07472be4..ae42b186 100644 --- a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -30,6 +30,7 @@ type KiroPayload struct { type KiroInferenceConfig struct { MaxTokens int `json:"maxTokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` } // KiroConversationState holds the conversation context @@ -136,9 +137,15 @@ func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bo // Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { // Extract max_tokens for potential use in inferenceConfig + // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens) + const kiroMaxOutputTokens = 32000 var maxTokens int64 if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() { maxTokens = mt.Int() + if maxTokens == -1 { + maxTokens = kiroMaxOutputTokens + log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens) + } } // Extract temperature if specified @@ -149,6 +156,15 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA hasTemperature = true } + // Extract top_p if specified + var topP float64 + var hasTopP bool + if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() { + topP = tp.Float() + hasTopP = true + log.Debugf("kiro: extracted top_p: %.2f", topP) + } + // Normalize origin value for Kiro API compatibility origin = normalizeOrigin(origin) log.Debugf("kiro: normalized origin value: %s", origin) @@ -164,8 +180,26 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA // Extract system prompt systemPrompt := extractSystemPrompt(claudeBody) - // Check for thinking mode - 
thinkingEnabled, budgetTokens := checkThinkingMode(claudeBody) + // Check for thinking mode using the comprehensive IsThinkingEnabled function + // This supports Claude API format, OpenAI reasoning_effort, and AMP/Cursor format + thinkingEnabled := IsThinkingEnabled(claudeBody) + _, budgetTokens := checkThinkingMode(claudeBody) // Get budget tokens from Claude format if available + if budgetTokens <= 0 { + // Calculate budgetTokens based on max_tokens if available + // Use 50% of max_tokens for thinking, with min 8000 and max 24000 + if maxTokens > 0 { + budgetTokens = maxTokens / 2 + if budgetTokens < 8000 { + budgetTokens = 8000 + } + if budgetTokens > 24000 { + budgetTokens = 24000 + } + log.Debugf("kiro: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens) + } else { + budgetTokens = 16000 // Default budget tokens + } + } // Inject timestamp context timestamp := time.Now().Format("2006-01-02 15:04:05 MST") @@ -185,6 +219,17 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA systemPrompt += kirocommon.KiroAgenticSystemPrompt } + // Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints + // Claude tool_choice values: {"type": "auto/any/tool", "name": "..."} + toolChoiceHint := extractClaudeToolChoiceHint(claudeBody) + if toolChoiceHint != "" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += toolChoiceHint + log.Debugf("kiro: injected tool_choice hint into system prompt") + } + // Inject thinking hint when thinking mode is enabled if thinkingEnabled { if systemPrompt != "" { @@ -235,7 +280,7 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA // Build inferenceConfig if we have any inference parameters var inferenceConfig *KiroInferenceConfig - if maxTokens > 0 || hasTemperature { + if maxTokens > 0 || hasTemperature || hasTopP { inferenceConfig = &KiroInferenceConfig{} if maxTokens > 0 { 
inferenceConfig.MaxTokens = int(maxTokens) @@ -243,6 +288,9 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA if hasTemperature { inferenceConfig.Temperature = temperature } + if hasTopP { + inferenceConfig.TopP = topP + } } payload := KiroPayload{ @@ -324,6 +372,93 @@ func checkThinkingMode(claudeBody []byte) (bool, int64) { return thinkingEnabled, budgetTokens } +// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled. +// This is used by the executor to determine whether to parse tags in responses. +// When thinking is NOT enabled in the request, tags in responses should be +// treated as regular text content, not as thinking blocks. +// +// Supports multiple formats: +// - Claude API format: thinking.type = "enabled" +// - OpenAI format: reasoning_effort parameter +// - AMP/Cursor format: interleaved in system prompt +func IsThinkingEnabled(body []byte) bool { + // Check Claude API format first (thinking.type = "enabled") + enabled, _ := checkThinkingMode(body) + if enabled { + log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)") + return true + } + + // Check OpenAI format: reasoning_effort parameter + // Valid values: "low", "medium", "high", "auto" (not "none") + reasoningEffort := gjson.GetBytes(body, "reasoning_effort") + if reasoningEffort.Exists() { + effort := reasoningEffort.String() + if effort != "" && effort != "none" { + log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort) + return true + } + } + + // Check AMP/Cursor format: interleaved in system prompt + // This is how AMP client passes thinking configuration + bodyStr := string(body) + if strings.Contains(bodyStr, "") && strings.Contains(bodyStr, "") { + // Extract thinking mode value + startTag := "" + endTag := "" + startIdx := strings.Index(bodyStr, startTag) + if startIdx >= 0 { + startIdx += len(startTag) + endIdx := strings.Index(bodyStr[startIdx:], endTag) + if endIdx >= 0 { + 
thinkingMode := bodyStr[startIdx : startIdx+endIdx] + if thinkingMode == "interleaved" || thinkingMode == "enabled" { + log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode) + return true + } + } + } + } + + // Check OpenAI format: max_completion_tokens with reasoning (o1-style) + // Some clients use this to indicate reasoning mode + if gjson.GetBytes(body, "max_completion_tokens").Exists() { + // If max_completion_tokens is set, check if model name suggests reasoning + model := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(model), "thinking") || + strings.Contains(strings.ToLower(model), "reason") { + log.Debugf("kiro: thinking mode enabled via model name hint: %s", model) + return true + } + } + + log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)") + return false +} + +// shortenToolNameIfNeeded shortens tool names that exceed 64 characters. +// MCP tools often have long names like "mcp__server-name__tool-name". +// This preserves the "mcp__" prefix and last segment when possible. 
+func shortenToolNameIfNeeded(name string) string { + const limit = 64 + if len(name) <= limit { + return name + } + // For MCP tools, try to preserve prefix and last segment + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 0 { + cand := "mcp__" + name[idx+2:] + if len(cand) > limit { + return cand[:limit] + } + return cand + } + } + return name[:limit] +} + // convertClaudeToolsToKiro converts Claude tools to Kiro format func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { var kiroTools []KiroToolWrapper @@ -336,6 +471,13 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { description := tool.Get("description").String() inputSchema := tool.Get("input_schema").Value() + // Shorten tool name if it exceeds 64 characters (common with MCP tools) + originalName := name + name = shortenToolNameIfNeeded(name) + if name != originalName { + log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name) + } + // CRITICAL FIX: Kiro API requires non-empty description if strings.TrimSpace(description) == "" { description = fmt.Sprintf("Tool: %s", name) @@ -467,6 +609,34 @@ func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult { return unique } +// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint. +// Claude tool_choice values: +// - {"type": "auto"}: Model decides (default, no hint needed) +// - {"type": "any"}: Must use at least one tool +// - {"type": "tool", "name": "..."}: Must use specific tool +func extractClaudeToolChoiceHint(claudeBody []byte) string { + toolChoice := gjson.GetBytes(claudeBody, "tool_choice") + if !toolChoice.Exists() { + return "" + } + + toolChoiceType := toolChoice.Get("type").String() + switch toolChoiceType { + case "any": + return "[INSTRUCTION: You MUST use at least one of the available tools to respond. 
Do not respond with text only - always make a tool call.]" + case "tool": + toolName := toolChoice.Get("name").String() + if toolName != "" { + return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName) + } + case "auto": + // Default behavior, no hint needed + return "" + } + + return "" +} + // BuildUserMessageStruct builds a user message and extracts tool results func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) { content := msg.Get("content") diff --git a/internal/translator/kiro/common/constants.go b/internal/translator/kiro/common/constants.go index 1d4b0330..96174b8c 100644 --- a/internal/translator/kiro/common/constants.go +++ b/internal/translator/kiro/common/constants.go @@ -12,6 +12,15 @@ const ( // ThinkingEndTag is the end tag for thinking blocks in responses. ThinkingEndTag = "" + // CodeFenceMarker is the markdown code fence marker. + CodeFenceMarker = "```" + + // AltCodeFenceMarker is the alternative markdown code fence marker. + AltCodeFenceMarker = "~~~" + + // InlineCodeMarker is the markdown inline code marker (backtick). + InlineCodeMarker = "`" + // KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. // AWS Kiro API has a 2-3 minute timeout for large file write operations. 
KiroAgenticSystemPrompt = ` diff --git a/internal/translator/kiro/openai/kiro_openai.go b/internal/translator/kiro/openai/kiro_openai.go index 35cd0424..d5822998 100644 --- a/internal/translator/kiro/openai/kiro_openai.go +++ b/internal/translator/kiro/openai/kiro_openai.go @@ -156,8 +156,9 @@ func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalReques } case "message_stop": - // Final event - emit [DONE] - results = append(results, BuildOpenAISSEDone()) + // Final event - do NOT emit [DONE] here + // The handler layer (openai_handlers.go) will send [DONE] when the stream closes + // Emitting [DONE] here would cause duplicate [DONE] markers case "ping": // Ping event with usage - optionally emit usage chunk diff --git a/internal/translator/kiro/openai/kiro_openai_request.go b/internal/translator/kiro/openai/kiro_openai_request.go index 4aaa8b4e..21b15aa0 100644 --- a/internal/translator/kiro/openai/kiro_openai_request.go +++ b/internal/translator/kiro/openai/kiro_openai_request.go @@ -29,6 +29,7 @@ type KiroPayload struct { type KiroInferenceConfig struct { MaxTokens int `json:"maxTokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` } // KiroConversationState holds the conversation context @@ -134,9 +135,15 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo // isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). 
func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { // Extract max_tokens for potential use in inferenceConfig + // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens) + const kiroMaxOutputTokens = 32000 var maxTokens int64 if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() { maxTokens = mt.Int() + if maxTokens == -1 { + maxTokens = kiroMaxOutputTokens + log.Debugf("kiro-openai: max_tokens=-1 converted to %d", kiroMaxOutputTokens) + } } // Extract temperature if specified @@ -147,6 +154,15 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s hasTemperature = true } + // Extract top_p if specified + var topP float64 + var hasTopP bool + if tp := gjson.GetBytes(openaiBody, "top_p"); tp.Exists() { + topP = tp.Float() + hasTopP = true + log.Debugf("kiro-openai: extracted top_p: %.2f", topP) + } + // Normalize origin value for Kiro API compatibility origin = normalizeOrigin(origin) log.Debugf("kiro-openai: normalized origin value: %s", origin) @@ -180,6 +196,54 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s systemPrompt += kirocommon.KiroAgenticSystemPrompt } + // Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints + // OpenAI tool_choice values: "none", "auto", "required", or {"type":"function","function":{"name":"..."}} + toolChoiceHint := extractToolChoiceHint(openaiBody) + if toolChoiceHint != "" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += toolChoiceHint + log.Debugf("kiro-openai: injected tool_choice hint into system prompt") + } + + // Handle response_format parameter - Kiro doesn't support it natively, so we inject system prompt hints + // OpenAI response_format: {"type": "json_object"} or {"type": "json_schema", "json_schema": {...}} + responseFormatHint := extractResponseFormatHint(openaiBody) + if responseFormatHint != "" { + 
if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += responseFormatHint + log.Debugf("kiro-openai: injected response_format hint into system prompt") + } + + // Check for thinking mode and inject thinking hint + // Supports OpenAI reasoning_effort parameter and model name hints + thinkingEnabled, budgetTokens := checkThinkingModeFromOpenAI(openaiBody) + if thinkingEnabled { + // Adjust budgetTokens based on max_tokens if not explicitly set by reasoning_effort + // Use 50% of max_tokens for thinking, with min 8000 and max 24000 + if maxTokens > 0 && budgetTokens == 16000 { // 16000 is the default, meaning not explicitly set + calculatedBudget := maxTokens / 2 + if calculatedBudget < 8000 { + calculatedBudget = 8000 + } + if calculatedBudget > 24000 { + calculatedBudget = 24000 + } + budgetTokens = calculatedBudget + log.Debugf("kiro-openai: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens) + } + + if systemPrompt != "" { + systemPrompt += "\n" + } + dynamicThinkingHint := fmt.Sprintf("interleaved%d", budgetTokens) + systemPrompt += dynamicThinkingHint + log.Debugf("kiro-openai: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens) + } + // Convert OpenAI tools to Kiro format kiroTools := convertOpenAIToolsToKiro(tools) @@ -220,7 +284,7 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s // Build inferenceConfig if we have any inference parameters var inferenceConfig *KiroInferenceConfig - if maxTokens > 0 || hasTemperature { + if maxTokens > 0 || hasTemperature || hasTopP { inferenceConfig = &KiroInferenceConfig{} if maxTokens > 0 { inferenceConfig.MaxTokens = int(maxTokens) @@ -228,6 +292,9 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s if hasTemperature { inferenceConfig.Temperature = temperature } + if hasTopP { + inferenceConfig.TopP = topP + } } payload := KiroPayload{ @@ -292,6 +359,28 @@ func 
extractSystemPromptFromOpenAI(messages gjson.Result) string { return strings.Join(systemParts, "\n") } +// shortenToolNameIfNeeded shortens tool names that exceed 64 characters. +// MCP tools often have long names like "mcp__server-name__tool-name". +// This preserves the "mcp__" prefix and last segment when possible. +func shortenToolNameIfNeeded(name string) string { + const limit = 64 + if len(name) <= limit { + return name + } + // For MCP tools, try to preserve prefix and last segment + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 0 { + cand := "mcp__" + name[idx+2:] + if len(cand) > limit { + return cand[:limit] + } + return cand + } + } + return name[:limit] +} + // convertOpenAIToolsToKiro converts OpenAI tools to Kiro format func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper { var kiroTools []KiroToolWrapper @@ -314,6 +403,13 @@ func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper { description := fn.Get("description").String() parameters := fn.Get("parameters").Value() + // Shorten tool name if it exceeds 64 characters (common with MCP tools) + originalName := name + name = shortenToolNameIfNeeded(name) + if name != originalName { + log.Debugf("kiro-openai: shortened tool name from '%s' to '%s'", originalName, name) + } + // CRITICAL FIX: Kiro API requires non-empty description if strings.TrimSpace(description) == "" { description = fmt.Sprintf("Tool: %s", name) @@ -584,6 +680,153 @@ func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResul return finalContent } +// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request. +// Returns (thinkingEnabled, budgetTokens). 
+// Supports: +// - reasoning_effort parameter (low/medium/high/auto) +// - Model name containing "thinking" or "reason" +// - tag in system prompt (AMP/Cursor format) +func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) { + var budgetTokens int64 = 16000 // Default budget + + // Check OpenAI format: reasoning_effort parameter + // Valid values: "low", "medium", "high", "auto" (not "none") + reasoningEffort := gjson.GetBytes(openaiBody, "reasoning_effort") + if reasoningEffort.Exists() { + effort := reasoningEffort.String() + if effort != "" && effort != "none" { + log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort) + // Adjust budget based on effort level + switch effort { + case "low": + budgetTokens = 8000 + case "medium": + budgetTokens = 16000 + case "high": + budgetTokens = 32000 + case "auto": + budgetTokens = 16000 + } + return true, budgetTokens + } + } + + // Check AMP/Cursor format: interleaved in system prompt + bodyStr := string(openaiBody) + if strings.Contains(bodyStr, "") && strings.Contains(bodyStr, "") { + startTag := "" + endTag := "" + startIdx := strings.Index(bodyStr, startTag) + if startIdx >= 0 { + startIdx += len(startTag) + endIdx := strings.Index(bodyStr[startIdx:], endTag) + if endIdx >= 0 { + thinkingMode := bodyStr[startIdx : startIdx+endIdx] + if thinkingMode == "interleaved" || thinkingMode == "enabled" { + log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode) + // Try to extract max_thinking_length if present + if maxLenStart := strings.Index(bodyStr, ""); maxLenStart >= 0 { + maxLenStart += len("") + if maxLenEnd := strings.Index(bodyStr[maxLenStart:], ""); maxLenEnd >= 0 { + maxLenStr := bodyStr[maxLenStart : maxLenStart+maxLenEnd] + if parsed, err := fmt.Sscanf(maxLenStr, "%d", &budgetTokens); err == nil && parsed == 1 { + log.Debugf("kiro-openai: extracted max_thinking_length: %d", budgetTokens) + } + } + } + return true, budgetTokens + } + } + } + } 
+ + // Check model name for thinking hints + model := gjson.GetBytes(openaiBody, "model").String() + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") { + log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model) + return true, budgetTokens + } + + log.Debugf("kiro-openai: no thinking mode detected in OpenAI request") + return false, budgetTokens +} + +// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint. +// OpenAI tool_choice values: +// - "none": Don't use any tools +// - "auto": Model decides (default, no hint needed) +// - "required": Must use at least one tool +// - {"type":"function","function":{"name":"..."}} : Must use specific tool +func extractToolChoiceHint(openaiBody []byte) string { + toolChoice := gjson.GetBytes(openaiBody, "tool_choice") + if !toolChoice.Exists() { + return "" + } + + // Handle string values + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "none": + // Note: When tool_choice is "none", we should ideally not pass tools at all + // But since we can't modify tool passing here, we add a strong hint + return "[INSTRUCTION: Do NOT use any tools. Respond with text only.]" + case "required": + return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]" + case "auto": + // Default behavior, no hint needed + return "" + } + } + + // Handle object value: {"type":"function","function":{"name":"..."}} + if toolChoice.IsObject() { + if toolChoice.Get("type").String() == "function" { + toolName := toolChoice.Get("function.name").String() + if toolName != "" { + return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. 
Do not use any other tool or respond with text only.]", toolName) + } + } + } + + return "" +} + +// extractResponseFormatHint extracts response_format from OpenAI request and returns a system prompt hint. +// OpenAI response_format values: +// - {"type": "text"}: Default, no hint needed +// - {"type": "json_object"}: Must respond with valid JSON +// - {"type": "json_schema", "json_schema": {...}}: Must respond with JSON matching schema +func extractResponseFormatHint(openaiBody []byte) string { + responseFormat := gjson.GetBytes(openaiBody, "response_format") + if !responseFormat.Exists() { + return "" + } + + formatType := responseFormat.Get("type").String() + switch formatType { + case "json_object": + return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]" + case "json_schema": + // Extract schema if provided + schema := responseFormat.Get("json_schema.schema") + if schema.Exists() { + schemaStr := schema.Raw + // Truncate if too long + if len(schemaStr) > 500 { + schemaStr = schemaStr[:500] + "..." + } + return fmt.Sprintf("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: %s. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]", schemaStr) + } + return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. 
Output raw JSON directly.]" + case "text": + // Default behavior, no hint needed + return "" + } + + return "" +} + // deduplicateToolResults removes duplicate tool results func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult { if len(toolResults) == 0 { diff --git a/internal/translator/kiro/openai/kiro_openai_stream.go b/internal/translator/kiro/openai/kiro_openai_stream.go index d550a8d8..e72d970e 100644 --- a/internal/translator/kiro/openai/kiro_openai_stream.go +++ b/internal/translator/kiro/openai/kiro_openai_stream.go @@ -5,7 +5,6 @@ package openai import ( "encoding/json" - "fmt" "time" "github.com/google/uuid" @@ -34,9 +33,12 @@ func NewOpenAIStreamState(model string) *OpenAIStreamState { } } -// FormatSSEEvent formats a JSON payload as an SSE event +// FormatSSEEvent formats a JSON payload for SSE streaming. +// Note: This returns raw JSON data without "data:" prefix. +// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go) +// to maintain architectural consistency and avoid double-prefix issues. func FormatSSEEvent(data []byte) string { - return fmt.Sprintf("data: %s", string(data)) + return string(data) } // BuildOpenAISSETextDelta creates an SSE event for text content delta @@ -130,9 +132,12 @@ func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) strin return FormatSSEEvent(result) } -// BuildOpenAISSEDone creates the final [DONE] SSE event +// BuildOpenAISSEDone creates the final [DONE] SSE event. +// Note: This returns raw "[DONE]" without "data:" prefix. +// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go) +// to maintain architectural consistency and avoid double-prefix issues. 
func BuildOpenAISSEDone() string { - return "data: [DONE]" + return "[DONE]" } // buildBaseChunk creates a base chunk structure for streaming From c3ed3b40ea99a442a52b47bc2f68947370c5f0fe Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sun, 14 Dec 2025 16:37:08 +0800 Subject: [PATCH 37/38] feat(kiro): Add token usage cross-validation and simplify thinking mode handling --- internal/runtime/executor/kiro_executor.go | 91 ++++++++++++++----- .../kiro/claude/kiro_claude_request.go | 7 +- .../kiro/openai/kiro_openai_request.go | 7 +- 3 files changed, 78 insertions(+), 27 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index b9a38272..e1d752a8 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -56,6 +56,34 @@ var ( usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first ) +// kiroRateMultipliers maps Kiro model IDs to their credit multipliers. +// Used for cross-validation of token estimates against upstream credit usage. +// Source: Kiro API model definitions +var kiroRateMultipliers = map[string]float64{ + "claude-haiku-4.5": 0.4, + "claude-sonnet-4": 1.3, + "claude-sonnet-4.5": 1.3, + "claude-opus-4.5": 2.2, + "auto": 1.0, // Default multiplier for auto model selection +} + +// getKiroRateMultiplier returns the credit multiplier for a given model ID. +// Returns 1.0 as default if model is not found. +func getKiroRateMultiplier(modelID string) float64 { + if multiplier, ok := kiroRateMultipliers[modelID]; ok { + return multiplier + } + return 1.0 +} + +// abs64 returns the absolute value of an int64. +func abs64(x int64) int64 { + if x < 0 { + return -x + } + return x +} + // kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. // This solves the "triple mismatch" problem where different endpoints require matching // Origin and X-Amz-Target header values. 
@@ -166,7 +194,8 @@ type KiroExecutor struct { // This is critical because OpenAI and Claude formats have different tool structures: // - OpenAI: tools[].function.name, tools[].function.description // - Claude: tools[].name, tools[].description -func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format) []byte { +// Returns the serialized JSON payload and a boolean indicating whether thinking mode was injected. +func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format) ([]byte, bool) { switch sourceFormat.String() { case "openai": log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) @@ -248,7 +277,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. // Rebuild payload with the correct origin for this endpoint // Each endpoint requires its matching Origin value in the request body - kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)", endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) @@ -358,7 +387,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed - kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed successfully, retrying request") continue } @@ -415,7 +444,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. if refreshedAuth != nil { auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) - kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed for 403, retrying request") continue } @@ -554,7 +583,10 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox // Rebuild payload with the correct origin for this endpoint // Each endpoint requires its matching Origin value in the request body - kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + // Kiro API always returns tags regardless of whether thinking mode was requested + // So we always enable thinking parsing for Kiro responses + thinkingEnabled := true log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) @@ -677,7 +709,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed - kiroPayload = 
buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed successfully, retrying stream request") continue } @@ -734,7 +766,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox if refreshedAuth != nil { auth = refreshedAuth accessToken, profileArn = kiroCredentials(auth) - kiroPayload = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) log.Infof("kiro: token refreshed for 403, retrying stream request") continue } @@ -758,7 +790,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox out := make(chan cliproxyexecutor.StreamChunk) - go func(resp *http.Response) { + go func(resp *http.Response, thinkingEnabled bool) { defer close(out) defer func() { if r := recover(); r != nil { @@ -772,21 +804,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } }() - // Check if thinking mode was enabled in the original request - // Only parse tags when thinking was explicitly requested - // Check multiple sources: original request, pre-translation payload, and translated body - // This handles different client formats (Claude API, OpenAI, AMP/Cursor) - thinkingEnabled := kiroclaude.IsThinkingEnabled(opts.OriginalRequest) - if !thinkingEnabled { - thinkingEnabled = kiroclaude.IsThinkingEnabled(req.Payload) - } - if !thinkingEnabled { - thinkingEnabled = kiroclaude.IsThinkingEnabled(body) - } - log.Debugf("kiro: stream thinkingEnabled = %v", thinkingEnabled) + // Kiro API always returns tags regardless of request parameters + // So we always enable thinking parsing for Kiro responses + log.Debugf("kiro: stream thinkingEnabled 
= %v (always true for Kiro)", thinkingEnabled) e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled) - }(httpResp) + }(httpResp, thinkingEnabled) return out, nil } @@ -2907,6 +2930,32 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) } + // Cross-validate token estimates using creditUsage if available + // This helps detect estimation errors and provides insight into actual usage + if upstreamCreditUsage > 0 { + rateMultiplier := getKiroRateMultiplier(model) + // Credit usage represents total tokens (input + output) * multiplier / 1000 + // Formula: total_tokens = creditUsage * 1000 / rateMultiplier + creditBasedTotal := int64(upstreamCreditUsage * 1000 / rateMultiplier) + localTotal := totalUsage.InputTokens + totalUsage.OutputTokens + + // Calculate difference percentage + var diffPercent float64 + if localTotal > 0 { + diffPercent = float64(abs64(creditBasedTotal-localTotal)) / float64(localTotal) * 100 + } + + // Log cross-validation result + if diffPercent > 20 { + // Significant mismatch - may indicate thinking tokens or estimation error + log.Warnf("kiro: token estimation mismatch > 20%% - local: %d, credit-based: %d (credits: %.4f, multiplier: %.1f, diff: %.1f%%)", + localTotal, creditBasedTotal, upstreamCreditUsage, rateMultiplier, diffPercent) + } else { + log.Debugf("kiro: token cross-validation - local: %d, credit-based: %d (credits: %.4f, multiplier: %.1f, diff: %.1f%%)", + localTotal, creditBasedTotal, upstreamCreditUsage, rateMultiplier, diffPercent) + } + } + // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn stopReason := upstreamStopReason if stopReason == "" { diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index ae42b186..052d671c 100644 --- 
a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -135,7 +135,8 @@ func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bo // isAgentic parameter enables chunked write optimization prompt for -agentic model variants. // isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). // Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. -func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { +// Returns the payload and a boolean indicating whether thinking mode was injected. +func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) { // Extract max_tokens for potential use in inferenceConfig // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens) const kiroMaxOutputTokens = 32000 @@ -307,10 +308,10 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA result, err := json.Marshal(payload) if err != nil { log.Debugf("kiro: failed to marshal payload: %v", err) - return nil + return nil, false } - return result + return result, thinkingEnabled } // normalizeOrigin normalizes origin value for Kiro API compatibility diff --git a/internal/translator/kiro/openai/kiro_openai_request.go b/internal/translator/kiro/openai/kiro_openai_request.go index 21b15aa0..cb97a340 100644 --- a/internal/translator/kiro/openai/kiro_openai_request.go +++ b/internal/translator/kiro/openai/kiro_openai_request.go @@ -133,7 +133,8 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo // origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. // isAgentic parameter enables chunked write optimization prompt for -agentic model variants. 
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). -func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { +// Returns the payload and a boolean indicating whether thinking mode was injected. +func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) { // Extract max_tokens for potential use in inferenceConfig // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens) const kiroMaxOutputTokens = 32000 @@ -311,10 +312,10 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s result, err := json.Marshal(payload) if err != nil { log.Debugf("kiro-openai: failed to marshal payload: %v", err) - return nil + return nil, false } - return result + return result, thinkingEnabled } // normalizeOrigin normalizes origin value for Kiro API compatibility From de0ea3ac49a0a9a9ef677342a2aa21cec2e8c68a Mon Sep 17 00:00:00 2001 From: Ravens2121 Date: Sun, 14 Dec 2025 16:46:02 +0800 Subject: [PATCH 38/38] fix(kiro): Always parse thinking tags from Kiro API responses Amp-Thread-ID: https://ampcode.com/threads/T-019b1c00-17b4-713d-a8cc-813b71181934 Co-authored-by: Amp --- internal/runtime/executor/kiro_executor.go | 54 ---------------------- 1 file changed, 54 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index e1d752a8..be6be1ed 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -56,34 +56,6 @@ var ( usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first ) -// kiroRateMultipliers maps Kiro model IDs to their credit multipliers. -// Used for cross-validation of token estimates against upstream credit usage. 
-// Source: Kiro API model definitions -var kiroRateMultipliers = map[string]float64{ - "claude-haiku-4.5": 0.4, - "claude-sonnet-4": 1.3, - "claude-sonnet-4.5": 1.3, - "claude-opus-4.5": 2.2, - "auto": 1.0, // Default multiplier for auto model selection -} - -// getKiroRateMultiplier returns the credit multiplier for a given model ID. -// Returns 1.0 as default if model is not found. -func getKiroRateMultiplier(modelID string) float64 { - if multiplier, ok := kiroRateMultipliers[modelID]; ok { - return multiplier - } - return 1.0 -} - -// abs64 returns the absolute value of an int64. -func abs64(x int64) int64 { - if x < 0 { - return -x - } - return x -} - // kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. // This solves the "triple mismatch" problem where different endpoints require matching // Origin and X-Amz-Target header values. @@ -2930,32 +2902,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) } - // Cross-validate token estimates using creditUsage if available - // This helps detect estimation errors and provides insight into actual usage - if upstreamCreditUsage > 0 { - rateMultiplier := getKiroRateMultiplier(model) - // Credit usage represents total tokens (input + output) * multiplier / 1000 - // Formula: total_tokens = creditUsage * 1000 / rateMultiplier - creditBasedTotal := int64(upstreamCreditUsage * 1000 / rateMultiplier) - localTotal := totalUsage.InputTokens + totalUsage.OutputTokens - - // Calculate difference percentage - var diffPercent float64 - if localTotal > 0 { - diffPercent = float64(abs64(creditBasedTotal-localTotal)) / float64(localTotal) * 100 - } - - // Log cross-validation result - if diffPercent > 20 { - // Significant mismatch - may indicate thinking tokens or estimation error - log.Warnf("kiro: token estimation mismatch > 20%% - local: %d, credit-based: %d (credits: %.4f, multiplier: 
%.1f, diff: %.1f%%)", - localTotal, creditBasedTotal, upstreamCreditUsage, rateMultiplier, diffPercent) - } else { - log.Debugf("kiro: token cross-validation - local: %d, credit-based: %d (credits: %.4f, multiplier: %.1f, diff: %.1f%%)", - localTotal, creditBasedTotal, upstreamCreditUsage, rateMultiplier, diffPercent) - } - } - // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn stopReason := upstreamStopReason if stopReason == "" {