mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-26 04:16:09 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92ca5078c1 | ||
|
|
aca8523060 | ||
|
|
1ea0cff3a4 | ||
|
|
75793a18f0 | ||
|
|
58866b21cb | ||
|
|
db80b20bc2 | ||
|
|
05b499fb83 | ||
|
|
7621ec609e | ||
|
|
9f511f0024 | ||
|
|
ba6aa5fbbe | ||
|
|
1c52a89535 | ||
|
|
e7cedbee6e | ||
|
|
15c3cc3a50 |
@@ -11,7 +11,7 @@ The Plus release stays in lockstep with the mainline features.
|
|||||||
## Differences from the Mainline
|
## Differences from the Mainline
|
||||||
|
|
||||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)
|
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
## 与主线版本版本差异
|
## 与主线版本版本差异
|
||||||
|
|
||||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供
|
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package management
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -23,6 +26,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
@@ -1745,6 +1749,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for duplicate BXAuth before authentication
|
||||||
|
bxAuth := iflowauth.ExtractBXAuth(cookieValue)
|
||||||
|
if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
|
||||||
|
return
|
||||||
|
} else if existingFile != "" {
|
||||||
|
existingFileName := filepath.Base(existingFile)
|
||||||
|
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||||
tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
|
tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
|
||||||
if errAuth != nil {
|
if errAuth != nil {
|
||||||
@@ -1767,11 +1782,12 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tokenStorage.Email = email
|
tokenStorage.Email = email
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
|
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("iflow-%s.json", fileName),
|
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||||
Provider: "iflow",
|
Provider: "iflow",
|
||||||
FileName: fmt.Sprintf("iflow-%s.json", fileName),
|
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": email,
|
"email": email,
|
||||||
@@ -2142,9 +2158,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
|
|
||||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if err, ok := getOAuthStatus(state); ok {
|
if statusValue, ok := getOAuthStatus(state); ok {
|
||||||
if err != "" {
|
if statusValue != "" {
|
||||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
// Check for device_code prefix (Kiro AWS Builder ID flow)
|
||||||
|
// Format: "device_code|verification_url|user_code"
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
if strings.HasPrefix(statusValue, "device_code|") {
|
||||||
|
parts := strings.SplitN(statusValue, "|", 3)
|
||||||
|
if len(parts) == 3 {
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": "device_code",
|
||||||
|
"verification_url": parts[1],
|
||||||
|
"user_code": parts[2],
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check for auth_url prefix (Kiro social auth flow)
|
||||||
|
// Format: "auth_url|url"
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
if strings.HasPrefix(statusValue, "auth_url|") {
|
||||||
|
authURL := strings.TrimPrefix(statusValue, "auth_url|")
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": "auth_url",
|
||||||
|
"url": authURL,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Otherwise treat as error
|
||||||
|
c.JSON(200, gin.H{"status": "error", "error": statusValue})
|
||||||
} else {
|
} else {
|
||||||
c.JSON(200, gin.H{"status": "wait"})
|
c.JSON(200, gin.H{"status": "wait"})
|
||||||
return
|
return
|
||||||
@@ -2154,3 +2196,295 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
deleteOAuthStatus(state)
|
deleteOAuthStatus(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const kiroCallbackPort = 9876
|
||||||
|
|
||||||
|
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get the login method from query parameter (default: aws for device code flow)
|
||||||
|
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
|
||||||
|
if method == "" {
|
||||||
|
method = "aws"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Initializing Kiro authentication...")
|
||||||
|
|
||||||
|
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
switch method {
|
||||||
|
case "aws", "builder-id":
|
||||||
|
// AWS Builder ID uses device code flow (no callback needed)
|
||||||
|
go func() {
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
|
||||||
|
|
||||||
|
// Step 1: Register client
|
||||||
|
fmt.Println("Registering client...")
|
||||||
|
regResp, err := ssoClient.RegisterClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to register client: %v", err)
|
||||||
|
setOAuthStatus(state, "Failed to register client")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Start device authorization
|
||||||
|
fmt.Println("Starting device authorization...")
|
||||||
|
authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to start device auth: %v", err)
|
||||||
|
setOAuthStatus(state, "Failed to start device authorization")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the verification URL for the frontend to display
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
|
||||||
|
|
||||||
|
// Step 3: Poll for token
|
||||||
|
fmt.Println("Waiting for authorization...")
|
||||||
|
interval := 5 * time.Second
|
||||||
|
if authResp.Interval > 0 {
|
||||||
|
interval = time.Duration(authResp.Interval) * time.Second
|
||||||
|
}
|
||||||
|
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
setOAuthStatus(state, "Authorization cancelled")
|
||||||
|
return
|
||||||
|
case <-time.After(interval):
|
||||||
|
tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "authorization_pending") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "slow_down") {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Errorf("Token creation failed: %v", err)
|
||||||
|
setOAuthStatus(state, "Token creation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success! Save the token
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
|
||||||
|
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||||
|
if idPart == "" {
|
||||||
|
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenResp.AccessToken,
|
||||||
|
"refresh_token": tokenResp.RefreshToken,
|
||||||
|
"expires_at": expiresAt.Format(time.RFC3339),
|
||||||
|
"auth_method": "builder-id",
|
||||||
|
"provider": "AWS",
|
||||||
|
"client_id": regResp.ClientID,
|
||||||
|
"client_secret": regResp.ClientSecret,
|
||||||
|
"email": email,
|
||||||
|
"last_refresh": now.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf("Authenticated as: %s\n", email)
|
||||||
|
}
|
||||||
|
deleteOAuthStatus(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setOAuthStatus(state, "Authorization timed out")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return immediately with the state for polling
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"})
|
||||||
|
|
||||||
|
case "google", "github":
|
||||||
|
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
|
||||||
|
provider := "Google"
|
||||||
|
if method == "github" {
|
||||||
|
provider = "Github"
|
||||||
|
}
|
||||||
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute kiro callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if isWebUI {
|
||||||
|
defer stopCallbackForwarder(kiroCallbackPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||||
|
|
||||||
|
// Generate PKCE codes
|
||||||
|
codeVerifier, codeChallenge, err := generateKiroPKCE()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate PKCE: %v", err)
|
||||||
|
setOAuthStatus(state, "Failed to generate PKCE")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build login URL
|
||||||
|
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||||
|
"https://prod.us-east-1.auth.desktop.kiro.dev",
|
||||||
|
provider,
|
||||||
|
url.QueryEscape(kiroauth.KiroRedirectURI),
|
||||||
|
codeChallenge,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store auth URL for frontend
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
setOAuthStatus(state, "auth_url|"+authURL)
|
||||||
|
|
||||||
|
// Wait for callback file
|
||||||
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
log.Error("oauth flow timed out")
|
||||||
|
setOAuthStatus(state, "OAuth flow timed out")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
|
var m map[string]string
|
||||||
|
_ = json.Unmarshal(data, &m)
|
||||||
|
_ = os.Remove(waitFile)
|
||||||
|
if errStr := m["error"]; errStr != "" {
|
||||||
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
|
setOAuthStatus(state, "Authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m["state"] != state {
|
||||||
|
log.Errorf("State mismatch")
|
||||||
|
setOAuthStatus(state, "State mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := m["code"]
|
||||||
|
if code == "" {
|
||||||
|
log.Error("No authorization code received")
|
||||||
|
setOAuthStatus(state, "No authorization code received")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange code for tokens
|
||||||
|
tokenReq := &kiroauth.CreateTokenRequest{
|
||||||
|
Code: code,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
RedirectURI: kiroauth.KiroRedirectURI,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
|
||||||
|
if errToken != nil {
|
||||||
|
log.Errorf("Failed to exchange code for tokens: %v", errToken)
|
||||||
|
setOAuthStatus(state, "Failed to exchange code for tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the token
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
|
||||||
|
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||||
|
if idPart == "" {
|
||||||
|
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenResp.AccessToken,
|
||||||
|
"refresh_token": tokenResp.RefreshToken,
|
||||||
|
"profile_arn": tokenResp.ProfileArn,
|
||||||
|
"expires_at": expiresAt.Format(time.RFC3339),
|
||||||
|
"auth_method": "social",
|
||||||
|
"provider": provider,
|
||||||
|
"email": email,
|
||||||
|
"last_refresh": now.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf("Authenticated as: %s\n", email)
|
||||||
|
}
|
||||||
|
deleteOAuthStatus(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
setOAuthStatus(state, "")
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"})
|
||||||
|
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
|
||||||
|
func generateKiroPKCE() (verifier, challenge string, err error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
|
||||||
|
return verifier, challenge, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -64,7 +65,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
// Modify incoming responses to handle gzip without Content-Encoding
|
// Modify incoming responses to handle gzip without Content-Encoding
|
||||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
// Only process successful responses
|
// Log upstream error responses for diagnostics (502, 503, etc.)
|
||||||
|
// These are NOT proxy connection errors - the upstream responded with an error status
|
||||||
|
if resp.StatusCode >= 500 {
|
||||||
|
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||||
|
} else if resp.StatusCode >= 400 {
|
||||||
|
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process successful responses for gzip decompression
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -148,15 +157,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error handler for proxy failures
|
// Error handler for proxy failures with detailed error classification for diagnostics
|
||||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
// Check if this is a client-side cancellation (normal behavior)
|
// Classify the error type for better diagnostics
|
||||||
|
var errType string
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
errType = "timeout"
|
||||||
|
} else if errors.Is(err, context.Canceled) {
|
||||||
|
errType = "canceled"
|
||||||
|
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
errType = "dial_timeout"
|
||||||
|
} else if _, ok := err.(net.Error); ok {
|
||||||
|
errType = "network_error"
|
||||||
|
} else {
|
||||||
|
errType = "connection_error"
|
||||||
|
}
|
||||||
|
|
||||||
// Don't log as error for context canceled - it's usually client closing connection
|
// Don't log as error for context canceled - it's usually client closing connection
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path)
|
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||||
} else {
|
} else {
|
||||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
rw.WriteHeader(http.StatusBadGateway)
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||||
|
|||||||
@@ -421,6 +421,18 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/kiro/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if state != "" {
|
||||||
|
file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state)
|
||||||
|
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
|
})
|
||||||
|
|
||||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,6 +598,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||||
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package iflow
|
package iflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,3 +39,61 @@ func SanitizeIFlowFileName(raw string) string {
|
|||||||
}
|
}
|
||||||
return strings.TrimSpace(result.String())
|
return strings.TrimSpace(result.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractBXAuth extracts the BXAuth value from a cookie string.
|
||||||
|
func ExtractBXAuth(cookie string) string {
|
||||||
|
parts := strings.Split(cookie, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "BXAuth=") {
|
||||||
|
return strings.TrimPrefix(part, "BXAuth=")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckDuplicateBXAuth checks if the given BXAuth value already exists in any iflow auth file.
|
||||||
|
// Returns the path of the existing file if found, empty string otherwise.
|
||||||
|
func CheckDuplicateBXAuth(authDir, bxAuth string) (string, error) {
|
||||||
|
if bxAuth == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(authDir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("read auth dir failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := entry.Name()
|
||||||
|
if !strings.HasPrefix(name, "iflow-") || !strings.HasSuffix(name, ".json") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(authDir, name)
|
||||||
|
data, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenData struct {
|
||||||
|
Cookie string `json:"cookie"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
existingBXAuth := ExtractBXAuth(tokenData.Cookie)
|
||||||
|
if existingBXAuth != "" && existingBXAuth == bxAuth {
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -506,11 +506,18 @@ func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenS
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only save the BXAuth field from the cookie
|
||||||
|
bxAuth := ExtractBXAuth(data.Cookie)
|
||||||
|
cookieToSave := ""
|
||||||
|
if bxAuth != "" {
|
||||||
|
cookieToSave = "BXAuth=" + bxAuth + ";"
|
||||||
|
}
|
||||||
|
|
||||||
return &IFlowTokenStorage{
|
return &IFlowTokenStorage{
|
||||||
APIKey: data.APIKey,
|
APIKey: data.APIKey,
|
||||||
Email: data.Email,
|
Email: data.Email,
|
||||||
Expire: data.Expire,
|
Expire: data.Expire,
|
||||||
Cookie: data.Cookie,
|
Cookie: cookieToSave,
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
LastRefresh: time.Now().Format(time.RFC3339),
|
||||||
Type: "iflow",
|
Type: "iflow",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -126,8 +126,8 @@ func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, s
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createToken exchanges the authorization code for tokens.
|
// CreateToken exchanges the authorization code for tokens.
|
||||||
func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||||
@@ -326,7 +326,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
|||||||
RedirectURI: KiroRedirectURI,
|
RedirectURI: KiroRedirectURI,
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenResp, err := c.createToken(ctx, tokenReq)
|
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -37,6 +39,16 @@ func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for duplicate BXAuth before authentication
|
||||||
|
bxAuth := iflow.ExtractBXAuth(cookie)
|
||||||
|
if existingFile, err := iflow.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil {
|
||||||
|
fmt.Printf("Failed to check duplicate: %v\n", err)
|
||||||
|
return
|
||||||
|
} else if existingFile != "" {
|
||||||
|
fmt.Printf("Duplicate BXAuth found, authentication already exists: %s\n", filepath.Base(existingFile))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticate with cookie
|
// Authenticate with cookie
|
||||||
auth := iflow.NewIFlowAuth(cfg)
|
auth := iflow.NewIFlowAuth(cfg)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -82,5 +94,5 @@ func promptForCookie(promptFn func(string) (string, error)) (string, error) {
|
|||||||
// getAuthFilePath returns the auth file path for the given provider and email
|
// getAuthFilePath returns the auth file path for the given provider and email
|
||||||
func getAuthFilePath(cfg *config.Config, provider, email string) string {
|
func getAuthFilePath(cfg *config.Config, provider, email string) string {
|
||||||
fileName := iflow.SanitizeIFlowFileName(email)
|
fileName := iflow.SanitizeIFlowFileName(email)
|
||||||
return fmt.Sprintf("%s/%s-%s.json", cfg.AuthDir, provider, fileName)
|
return fmt.Sprintf("%s/%s-%s-%d.json", cfg.AuthDir, provider, fileName, time.Now().Unix())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -895,6 +895,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
|
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-sonnet-4-5",
|
ID: "kiro-claude-sonnet-4-5",
|
||||||
@@ -906,6 +907,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
|
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-sonnet-4",
|
ID: "kiro-claude-sonnet-4",
|
||||||
@@ -917,6 +919,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
|
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-haiku-4-5",
|
ID: "kiro-claude-haiku-4-5",
|
||||||
@@ -928,6 +931,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
|
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||||
{
|
{
|
||||||
@@ -940,6 +944,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
|
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-sonnet-4-5-agentic",
|
ID: "kiro-claude-sonnet-4-5-agentic",
|
||||||
@@ -951,6 +956,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
|
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-sonnet-4-agentic",
|
ID: "kiro-claude-sonnet-4-agentic",
|
||||||
@@ -962,6 +968,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
|
Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-haiku-4-5-agentic",
|
ID: "kiro-claude-haiku-4-5-agentic",
|
||||||
@@ -973,6 +980,7 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
|
Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -748,7 +748,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
case "claude":
|
case "claude", "kiro", "antigravity":
|
||||||
|
// Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client
|
||||||
result := map[string]any{
|
result := map[string]any{
|
||||||
"id": model.ID,
|
"id": model.ID,
|
||||||
"object": "model",
|
"object": "model",
|
||||||
@@ -763,6 +764,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
if model.DisplayName != "" {
|
if model.DisplayName != "" {
|
||||||
result["display_name"] = model.DisplayName
|
result["display_name"] = model.DisplayName
|
||||||
}
|
}
|
||||||
|
// Add thinking support for Claude Code client
|
||||||
|
// Claude Code checks for "thinking" field (simple boolean) to enable tab toggle
|
||||||
|
// Also add "extended_thinking" for detailed budget info
|
||||||
|
if model.Thinking != nil {
|
||||||
|
result["thinking"] = true
|
||||||
|
result["extended_thinking"] = map[string]any{
|
||||||
|
"supported": true,
|
||||||
|
"min": model.Thinking.Min,
|
||||||
|
"max": model.Thinking.Max,
|
||||||
|
"zero_allowed": model.Thinking.ZeroAllowed,
|
||||||
|
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
||||||
|
}
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
case "gemini":
|
case "gemini":
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -54,13 +54,14 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
||||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||||
|
if modelOverride != "" {
|
||||||
translated = e.overrideModel(translated, modelOverride)
|
translated = e.overrideModel(translated, modelOverride)
|
||||||
}
|
}
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
if upstreamModel != "" {
|
if upstreamModel != "" && modelOverride == "" {
|
||||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||||
}
|
}
|
||||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||||
@@ -148,13 +149,14 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||||
|
if modelOverride != "" {
|
||||||
translated = e.overrideModel(translated, modelOverride)
|
translated = e.overrideModel(translated, modelOverride)
|
||||||
}
|
}
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
if upstreamModel != "" {
|
if upstreamModel != "" && modelOverride == "" {
|
||||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||||
}
|
}
|
||||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||||
|
|||||||
@@ -52,10 +52,14 @@ func applyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
|
|||||||
if len(metadata) == 0 {
|
if len(metadata) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
if !util.ModelSupportsThinking(model) {
|
if field == "" {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
if field == "" {
|
baseModel := util.ResolveOriginalModel(model, metadata)
|
||||||
|
if baseModel == "" {
|
||||||
|
baseModel = model
|
||||||
|
}
|
||||||
|
if !util.ModelSupportsThinking(baseModel) && !util.IsOpenAICompatibilityModel(baseModel) {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
|
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
|
||||||
@@ -226,6 +230,9 @@ func normalizeThinkingConfig(payload []byte, model string) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !util.ModelSupportsThinking(model) {
|
if !util.ModelSupportsThinking(model) {
|
||||||
|
if util.IsOpenAICompatibilityModel(model) {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
return stripThinkingFields(payload)
|
return stripThinkingFields(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,43 +2,107 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tiktoken-go/tokenizer"
|
"github.com/tiktoken-go/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||||
|
var tokenizerCache sync.Map
|
||||||
|
|
||||||
|
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||||
|
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||||
|
type TokenizerWrapper struct {
|
||||||
|
Codec tokenizer.Codec
|
||||||
|
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the token count with adjustment factor applied
|
||||||
|
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||||
|
count, err := tw.Codec.Count(text)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||||
|
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTokenizer returns a cached tokenizer for the given model.
|
||||||
|
// This improves performance by avoiding repeated tokenizer creation.
|
||||||
|
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||||
|
// Check cache first
|
||||||
|
if cached, ok := tokenizerCache.Load(model); ok {
|
||||||
|
return cached.(*TokenizerWrapper), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache miss, create new tokenizer
|
||||||
|
wrapper, err := tokenizerForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store in cache (use LoadOrStore to handle race conditions)
|
||||||
|
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||||
|
return actual.(*TokenizerWrapper), nil
|
||||||
|
}
|
||||||
|
|
||||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||||
|
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
|
||||||
|
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||||
|
// because tiktoken may underestimate Claude's actual token count
|
||||||
|
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||||
|
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var enc tokenizer.Codec
|
||||||
|
var err error
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case sanitized == "":
|
case sanitized == "":
|
||||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT5)
|
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT5)
|
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT41)
|
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT4)
|
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||||
case strings.HasPrefix(sanitized, "o1"):
|
case strings.HasPrefix(sanitized, "o1"):
|
||||||
return tokenizer.ForModel(tokenizer.O1)
|
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||||
case strings.HasPrefix(sanitized, "o3"):
|
case strings.HasPrefix(sanitized, "o3"):
|
||||||
return tokenizer.ForModel(tokenizer.O3)
|
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||||
case strings.HasPrefix(sanitized, "o4"):
|
case strings.HasPrefix(sanitized, "o4"):
|
||||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||||
default:
|
default:
|
||||||
return tokenizer.Get(tokenizer.O200kBase)
|
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||||
if enc == nil {
|
if enc == nil {
|
||||||
return 0, fmt.Errorf("encoder is nil")
|
return 0, fmt.Errorf("encoder is nil")
|
||||||
}
|
}
|
||||||
@@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count text tokens
|
||||||
count, err := enc.Count(joined)
|
count, err := enc.Count(joined)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return int64(count), nil
|
|
||||||
|
// Extract and add image tokens from placeholders
|
||||||
|
imageTokens := extractImageTokens(joined)
|
||||||
|
|
||||||
|
return int64(count) + int64(imageTokens), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||||
|
// This handles Claude's message format with system, messages, and tools.
|
||||||
|
// Image tokens are estimated based on image dimensions when available.
|
||||||
|
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||||
|
if enc == nil {
|
||||||
|
return 0, fmt.Errorf("encoder is nil")
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
root := gjson.ParseBytes(payload)
|
||||||
|
segments := make([]string, 0, 32)
|
||||||
|
|
||||||
|
// Collect system prompt (can be string or array of content blocks)
|
||||||
|
collectClaudeSystem(root.Get("system"), &segments)
|
||||||
|
|
||||||
|
// Collect messages
|
||||||
|
collectClaudeMessages(root.Get("messages"), &segments)
|
||||||
|
|
||||||
|
// Collect tools
|
||||||
|
collectClaudeTools(root.Get("tools"), &segments)
|
||||||
|
|
||||||
|
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||||
|
if joined == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count text tokens
|
||||||
|
count, err := enc.Count(joined)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and add image tokens from placeholders
|
||||||
|
imageTokens := extractImageTokens(joined)
|
||||||
|
|
||||||
|
return int64(count) + int64(imageTokens), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
||||||
|
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
||||||
|
|
||||||
|
// extractImageTokens extracts image token estimates from placeholder text.
|
||||||
|
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
||||||
|
func extractImageTokens(text string) int {
|
||||||
|
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
||||||
|
total := 0
|
||||||
|
for _, match := range matches {
|
||||||
|
if len(match) > 1 {
|
||||||
|
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
||||||
|
total += tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||||
|
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||||
|
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||||
|
func estimateImageTokens(width, height float64) int {
|
||||||
|
if width <= 0 || height <= 0 {
|
||||||
|
// No valid dimensions, use default estimate (medium-sized image)
|
||||||
|
return 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := int(width * height / 750)
|
||||||
|
|
||||||
|
// Apply bounds
|
||||||
|
if tokens < 85 {
|
||||||
|
tokens = 85
|
||||||
|
}
|
||||||
|
if tokens > 1590 {
|
||||||
|
tokens = 1590
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeSystem extracts text from Claude's system field.
|
||||||
|
// System can be a string or an array of content blocks.
|
||||||
|
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||||
|
if !system.Exists() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if system.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, system.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if system.IsArray() {
|
||||||
|
system.ForEach(func(_, block gjson.Result) bool {
|
||||||
|
blockType := block.Get("type").String()
|
||||||
|
if blockType == "text" || blockType == "" {
|
||||||
|
addIfNotEmpty(segments, block.Get("text").String())
|
||||||
|
}
|
||||||
|
// Also handle plain string blocks
|
||||||
|
if block.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, block.String())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeMessages extracts text from Claude's messages array.
|
||||||
|
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages.ForEach(func(_, message gjson.Result) bool {
|
||||||
|
addIfNotEmpty(segments, message.Get("role").String())
|
||||||
|
collectClaudeContent(message.Get("content"), segments)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeContent extracts text from Claude's content field.
|
||||||
|
// Content can be a string or an array of content blocks.
|
||||||
|
// For images, estimates token count based on dimensions when available.
|
||||||
|
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||||
|
if !content.Exists() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, content.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
content.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
switch partType {
|
||||||
|
case "text":
|
||||||
|
addIfNotEmpty(segments, part.Get("text").String())
|
||||||
|
case "image":
|
||||||
|
// Estimate image tokens based on dimensions if available
|
||||||
|
source := part.Get("source")
|
||||||
|
if source.Exists() {
|
||||||
|
width := source.Get("width").Float()
|
||||||
|
height := source.Get("height").Float()
|
||||||
|
if width > 0 && height > 0 {
|
||||||
|
tokens := estimateImageTokens(width, height)
|
||||||
|
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||||
|
} else {
|
||||||
|
// No dimensions available, use default estimate
|
||||||
|
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No source info, use default estimate
|
||||||
|
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
addIfNotEmpty(segments, part.Get("id").String())
|
||||||
|
addIfNotEmpty(segments, part.Get("name").String())
|
||||||
|
if input := part.Get("input"); input.Exists() {
|
||||||
|
addIfNotEmpty(segments, input.Raw)
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||||
|
collectClaudeContent(part.Get("content"), segments)
|
||||||
|
case "thinking":
|
||||||
|
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||||
|
default:
|
||||||
|
// For unknown types, try to extract any text content
|
||||||
|
if part.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, part.String())
|
||||||
|
} else if part.Type == gjson.JSON {
|
||||||
|
addIfNotEmpty(segments, part.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeTools extracts text from Claude's tools array.
|
||||||
|
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||||
|
addIfNotEmpty(segments, tool.Get("name").String())
|
||||||
|
addIfNotEmpty(segments, tool.Get("description").String())
|
||||||
|
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||||
|
addIfNotEmpty(segments, inputSchema.Raw)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||||
|
|||||||
@@ -25,33 +25,33 @@ func ModelSupportsThinking(model string) bool {
|
|||||||
// or min (0 if zero is allowed and mid <= 0).
|
// or min (0 if zero is allowed and mid <= 0).
|
||||||
func NormalizeThinkingBudget(model string, budget int) int {
|
func NormalizeThinkingBudget(model string, budget int) int {
|
||||||
if budget == -1 { // dynamic
|
if budget == -1 { // dynamic
|
||||||
if found, min, max, zeroAllowed, dynamicAllowed := thinkingRangeFromRegistry(model); found {
|
if found, minBudget, maxBudget, zeroAllowed, dynamicAllowed := thinkingRangeFromRegistry(model); found {
|
||||||
if dynamicAllowed {
|
if dynamicAllowed {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
mid := (min + max) / 2
|
mid := (minBudget + maxBudget) / 2
|
||||||
if mid <= 0 && zeroAllowed {
|
if mid <= 0 && zeroAllowed {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
if mid <= 0 {
|
if mid <= 0 {
|
||||||
return min
|
return minBudget
|
||||||
}
|
}
|
||||||
return mid
|
return mid
|
||||||
}
|
}
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
if found, min, max, zeroAllowed, _ := thinkingRangeFromRegistry(model); found {
|
if found, minBudget, maxBudget, zeroAllowed, _ := thinkingRangeFromRegistry(model); found {
|
||||||
if budget == 0 {
|
if budget == 0 {
|
||||||
if zeroAllowed {
|
if zeroAllowed {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return min
|
return minBudget
|
||||||
}
|
}
|
||||||
if budget < min {
|
if budget < minBudget {
|
||||||
return min
|
return minBudget
|
||||||
}
|
}
|
||||||
if budget > max {
|
if budget > maxBudget {
|
||||||
return max
|
return maxBudget
|
||||||
}
|
}
|
||||||
return budget
|
return budget
|
||||||
}
|
}
|
||||||
@@ -105,3 +105,16 @@ func NormalizeReasoningEffortLevel(model, effort string) (string, bool) {
|
|||||||
}
|
}
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOpenAICompatibilityModel reports whether the model is registered as an OpenAI-compatibility model.
|
||||||
|
// These models may not advertise Thinking metadata in the registry.
|
||||||
|
func IsOpenAICompatibilityModel(model string) bool {
|
||||||
|
if model == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
info := registry.GetGlobalRegistry().GetModelInfo(model)
|
||||||
|
if info == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(info.Type), "openai-compatibility")
|
||||||
|
}
|
||||||
|
|||||||
@@ -219,12 +219,12 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||||
// v6.1: Intelligent Buffered Streamer strategy
|
// v6.2: Immediate flush strategy for SSE streams
|
||||||
// Enhanced buffering with larger buffer size (16KB) and longer flush interval (120ms).
|
// SSE requires immediate data delivery to prevent client timeouts.
|
||||||
// Smart flush only when buffer is sufficiently filled (≥50%), dramatically reducing
|
// Previous buffering strategy (16KB buffer, 8KB threshold) caused delays
|
||||||
// flush frequency from ~12.5Hz to ~5-8Hz while maintaining low latency.
|
// because SSE events are typically small (< 1KB), leading to client retries.
|
||||||
writer := bufio.NewWriterSize(c.Writer, 16*1024) // 4KB → 16KB
|
writer := bufio.NewWriterSize(c.Writer, 4*1024) // 4KB buffer (smaller for faster flush)
|
||||||
ticker := time.NewTicker(120 * time.Millisecond) // 80ms → 120ms
|
ticker := time.NewTicker(50 * time.Millisecond) // 50ms interval for responsive streaming
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
var chunkIdx int
|
var chunkIdx int
|
||||||
@@ -238,10 +238,9 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.
|
|||||||
return
|
return
|
||||||
|
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Smart flush: only flush when buffer has sufficient data (≥50% full)
|
// Flush any buffered data on timer to ensure responsiveness
|
||||||
// This reduces flush frequency while ensuring data flows naturally
|
// For SSE, we flush whenever there's any data to prevent client timeouts
|
||||||
buffered := writer.Buffered()
|
if writer.Buffered() > 0 {
|
||||||
if buffered >= 8*1024 { // At least 8KB (50% of 16KB buffer)
|
|
||||||
if err := writer.Flush(); err != nil {
|
if err := writer.Flush(); err != nil {
|
||||||
// Error flushing, cancel and return
|
// Error flushing, cancel and return
|
||||||
cancel(err)
|
cancel(err)
|
||||||
@@ -254,6 +253,7 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream ended, flush remaining data
|
// Stream ended, flush remaining data
|
||||||
_ = writer.Flush()
|
_ = writer.Flush()
|
||||||
|
flusher.Flush()
|
||||||
cancel(nil)
|
cancel(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -263,6 +263,12 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.
|
|||||||
// The handler just needs to forward it without reassembly.
|
// The handler just needs to forward it without reassembly.
|
||||||
if len(chunk) > 0 {
|
if len(chunk) > 0 {
|
||||||
_, _ = writer.Write(chunk)
|
_, _ = writer.Write(chunk)
|
||||||
|
// Immediately flush for first few chunks to establish connection quickly
|
||||||
|
// This prevents client timeout/retry on slow backends like Kiro
|
||||||
|
if chunkIdx < 3 {
|
||||||
|
_ = writer.Flush()
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
chunkIdx++
|
chunkIdx++
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
return nil, fmt.Errorf("iflow authentication failed: missing account identifier")
|
return nil, fmt.Errorf("iflow authentication failed: missing account identifier")
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := fmt.Sprintf("iflow-%s.json", email)
|
fileName := fmt.Sprintf("iflow-%s-%d.json", email, time.Now().Unix())
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"email": email,
|
"email": email,
|
||||||
"api_key": tokenStorage.APIKey,
|
"api_key": tokenStorage.APIKey,
|
||||||
|
|||||||
@@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||||
|
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||||
d := 30 * time.Minute
|
d := 5 * time.Minute
|
||||||
return &d
|
return &d
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
"source": "aws-builder-id",
|
"source": "aws-builder-id",
|
||||||
"email": tokenData.Email,
|
"email": tokenData.Email,
|
||||||
},
|
},
|
||||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenData.Email != "" {
|
if tokenData.Email != "" {
|
||||||
@@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con
|
|||||||
"source": "google-oauth",
|
"source": "google-oauth",
|
||||||
"email": tokenData.Email,
|
"email": tokenData.Email,
|
||||||
},
|
},
|
||||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenData.Email != "" {
|
if tokenData.Email != "" {
|
||||||
@@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con
|
|||||||
"source": "github-oauth",
|
"source": "github-oauth",
|
||||||
"email": tokenData.Email,
|
"email": tokenData.Email,
|
||||||
},
|
},
|
||||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenData.Email != "" {
|
if tokenData.Email != "" {
|
||||||
@@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
|||||||
"source": "kiro-ide-import",
|
"source": "kiro-ide-import",
|
||||||
"email": tokenData.Email,
|
"email": tokenData.Email,
|
||||||
},
|
},
|
||||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Display the email if extracted
|
// Display the email if extracted
|
||||||
@@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
|||||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||||
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ type RefreshEvaluator interface {
|
|||||||
const (
|
const (
|
||||||
refreshCheckInterval = 5 * time.Second
|
refreshCheckInterval = 5 * time.Second
|
||||||
refreshPendingBackoff = time.Minute
|
refreshPendingBackoff = time.Minute
|
||||||
refreshFailureBackoff = 5 * time.Minute
|
refreshFailureBackoff = 1 * time.Minute
|
||||||
quotaBackoffBase = time.Second
|
quotaBackoffBase = time.Second
|
||||||
quotaBackoffMax = 30 * time.Minute
|
quotaBackoffMax = 30 * time.Minute
|
||||||
)
|
)
|
||||||
@@ -1471,7 +1471,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
|||||||
updated.Runtime = auth.Runtime
|
updated.Runtime = auth.Runtime
|
||||||
}
|
}
|
||||||
updated.LastRefreshedAt = now
|
updated.LastRefreshedAt = now
|
||||||
updated.NextRefreshAfter = time.Time{}
|
// Preserve NextRefreshAfter set by the Authenticator
|
||||||
|
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||||
|
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||||
updated.LastError = nil
|
updated.LastError = nil
|
||||||
updated.UpdatedAt = now
|
updated.UpdatedAt = now
|
||||||
_, _ = m.Update(ctx, updated)
|
_, _ = m.Update(ctx, updated)
|
||||||
|
|||||||
Reference in New Issue
Block a user