mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 23:33:24 +00:00
Compare commits
82 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10e77fcf24 | ||
|
|
bbb21d7c2b | ||
|
|
f6720f8dfa | ||
|
|
e19ab3a066 | ||
|
|
8f1dd69e72 | ||
|
|
f26da24a2f | ||
|
|
407020de0c | ||
|
|
8e4fbcaa7d | ||
|
|
09c339953d | ||
|
|
367a05bdf6 | ||
|
|
d20b71deb9 | ||
|
|
712ce9f781 | ||
|
|
a4a3274a55 | ||
|
|
716aa71f6e | ||
|
|
e8976f9898 | ||
|
|
8496cc2444 | ||
|
|
5ef2d59e05 | ||
|
|
07bb89ae80 | ||
|
|
27a5ad8ec2 | ||
|
|
707b07c5f5 | ||
|
|
4a764afd76 | ||
|
|
ecf49d574b | ||
|
|
188de4ff2a | ||
|
|
5a75ef8ffd | ||
|
|
07279f8746 | ||
|
|
71f788b13a | ||
|
|
59c62dc580 | ||
|
|
8fb1f114bc | ||
|
|
6a4cff6699 | ||
|
|
d5310a3300 | ||
|
|
de0ea3ac49 | ||
|
|
12116b018d | ||
|
|
c3ed3b40ea | ||
|
|
b80c2aabb0 | ||
|
|
f0a3eb574e | ||
|
|
bb15855443 | ||
|
|
14ce6aebd1 | ||
|
|
2fe83723f2 | ||
|
|
e73b9e10a6 | ||
|
|
9c04c18c04 | ||
|
|
81ae09d0ec | ||
|
|
01cf221167 | ||
|
|
cd8c86c6fb | ||
|
|
52d5fd1a67 | ||
|
|
7ecc7aabda | ||
|
|
79033aee34 | ||
|
|
b6ad243e9e | ||
|
|
92ca5078c1 | ||
|
|
aca8523060 | ||
|
|
1ea0cff3a4 | ||
|
|
75793a18f0 | ||
|
|
58866b21cb | ||
|
|
660aabc437 | ||
|
|
db80b20bc2 | ||
|
|
566120e8d5 | ||
|
|
f3f0f1717d | ||
|
|
05b499fb83 | ||
|
|
7621ec609e | ||
|
|
9f511f0024 | ||
|
|
374faa2640 | ||
|
|
ba6aa5fbbe | ||
|
|
1c52a89535 | ||
|
|
e7cedbee6e | ||
|
|
75ce0919a0 | ||
|
|
7f4f6bc9ca | ||
|
|
b8194e717c | ||
|
|
15c3cc3a50 | ||
|
|
d131435e25 | ||
|
|
6e43669498 | ||
|
|
5ab3032335 | ||
|
|
1215c635a0 | ||
|
|
54d4fd7f84 | ||
|
|
8dc690a638 | ||
|
|
fdeb84db2b | ||
|
|
84920cb670 | ||
|
|
204bba9dea | ||
|
|
35fdd7bc05 | ||
|
|
fc054db51a | ||
|
|
6e2306a5f2 | ||
|
|
b09e2115d1 | ||
|
|
40e7f066e4 | ||
|
|
07d21463ca |
@@ -11,7 +11,7 @@ The Plus release stays in lockstep with the mainline features.
|
||||
## Differences from the Mainline
|
||||
|
||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
## 与主线版本差异
|
||||
|
||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||
|
||||
## 贡献
|
||||
|
||||
|
||||
@@ -151,8 +151,8 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
|
||||
@@ -3,6 +3,9 @@ package management
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -23,9 +26,11 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -289,6 +294,54 @@ func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetAuthFileModels returns the models supported by a specific auth file
|
||||
func (h *Handler) GetAuthFileModels(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
if name == "" {
|
||||
c.JSON(400, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Try to find auth ID via authManager
|
||||
var authID string
|
||||
if h.authManager != nil {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name || auth.ID == name {
|
||||
authID = auth.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if authID == "" {
|
||||
authID = name // fallback to filename as ID
|
||||
}
|
||||
|
||||
// Get models from registry
|
||||
reg := registry.GetGlobalRegistry()
|
||||
models := reg.GetModelsForClient(authID)
|
||||
|
||||
result := make([]gin.H, 0, len(models))
|
||||
for _, m := range models {
|
||||
entry := gin.H{
|
||||
"id": m.ID,
|
||||
}
|
||||
if m.DisplayName != "" {
|
||||
entry["display_name"] = m.DisplayName
|
||||
}
|
||||
if m.Type != "" {
|
||||
entry["type"] = m.Type
|
||||
}
|
||||
if m.OwnedBy != "" {
|
||||
entry["owned_by"] = m.OwnedBy
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"models": result})
|
||||
}
|
||||
|
||||
// List auth files from disk when the auth manager is unavailable.
|
||||
func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||
@@ -1745,6 +1798,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicate BXAuth before authentication
|
||||
bxAuth := iflowauth.ExtractBXAuth(cookieValue)
|
||||
if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
|
||||
return
|
||||
} else if existingFile != "" {
|
||||
existingFileName := filepath.Base(existingFile)
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
|
||||
return
|
||||
}
|
||||
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
|
||||
if errAuth != nil {
|
||||
@@ -1767,11 +1831,12 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
tokenStorage.Email = email
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("iflow-%s.json", fileName),
|
||||
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||
Provider: "iflow",
|
||||
FileName: fmt.Sprintf("iflow-%s.json", fileName),
|
||||
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": email,
|
||||
@@ -2142,9 +2207,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := getOAuthStatus(state); ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
if statusValue, ok := getOAuthStatus(state); ok {
|
||||
if statusValue != "" {
|
||||
// Check for device_code prefix (Kiro AWS Builder ID flow)
|
||||
// Format: "device_code|verification_url|user_code"
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
if strings.HasPrefix(statusValue, "device_code|") {
|
||||
parts := strings.SplitN(statusValue, "|", 3)
|
||||
if len(parts) == 3 {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "device_code",
|
||||
"verification_url": parts[1],
|
||||
"user_code": parts[2],
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
// Check for auth_url prefix (Kiro social auth flow)
|
||||
// Format: "auth_url|url"
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
if strings.HasPrefix(statusValue, "auth_url|") {
|
||||
authURL := strings.TrimPrefix(statusValue, "auth_url|")
|
||||
c.JSON(200, gin.H{
|
||||
"status": "auth_url",
|
||||
"url": authURL,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Otherwise treat as error
|
||||
c.JSON(200, gin.H{"status": "error", "error": statusValue})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
return
|
||||
@@ -2154,3 +2245,295 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
}
|
||||
deleteOAuthStatus(state)
|
||||
}
|
||||
|
||||
const kiroCallbackPort = 9876
|
||||
|
||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get the login method from query parameter (default: aws for device code flow)
|
||||
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
|
||||
if method == "" {
|
||||
method = "aws"
|
||||
}
|
||||
|
||||
fmt.Println("Initializing Kiro authentication...")
|
||||
|
||||
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
|
||||
|
||||
switch method {
|
||||
case "aws", "builder-id":
|
||||
// AWS Builder ID uses device code flow (no callback needed)
|
||||
go func() {
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
|
||||
|
||||
// Step 1: Register client
|
||||
fmt.Println("Registering client...")
|
||||
regResp, err := ssoClient.RegisterClient(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to register client: %v", err)
|
||||
setOAuthStatus(state, "Failed to register client")
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Start device authorization
|
||||
fmt.Println("Starting device authorization...")
|
||||
authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to start device auth: %v", err)
|
||||
setOAuthStatus(state, "Failed to start device authorization")
|
||||
return
|
||||
}
|
||||
|
||||
// Store the verification URL for the frontend to display
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
|
||||
|
||||
// Step 3: Poll for token
|
||||
fmt.Println("Waiting for authorization...")
|
||||
interval := 5 * time.Second
|
||||
if authResp.Interval > 0 {
|
||||
interval = time.Duration(authResp.Interval) * time.Second
|
||||
}
|
||||
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
setOAuthStatus(state, "Authorization cancelled")
|
||||
return
|
||||
case <-time.After(interval):
|
||||
tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "authorization_pending") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(errStr, "slow_down") {
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
}
|
||||
log.Errorf("Token creation failed: %v", err)
|
||||
setOAuthStatus(state, "Token creation failed")
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Save the token
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||
if idPart == "" {
|
||||
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenResp.AccessToken,
|
||||
"refresh_token": tokenResp.RefreshToken,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"auth_method": "builder-id",
|
||||
"provider": "AWS",
|
||||
"client_id": regResp.ClientID,
|
||||
"client_secret": regResp.ClientSecret,
|
||||
"email": email,
|
||||
"last_refresh": now.Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if email != "" {
|
||||
fmt.Printf("Authenticated as: %s\n", email)
|
||||
}
|
||||
deleteOAuthStatus(state)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOAuthStatus(state, "Authorization timed out")
|
||||
}()
|
||||
|
||||
// Return immediately with the state for polling
|
||||
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"})
|
||||
|
||||
case "google", "github":
|
||||
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
|
||||
provider := "Google"
|
||||
if method == "github" {
|
||||
provider = "Github"
|
||||
}
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute kiro callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(kiroCallbackPort)
|
||||
}
|
||||
|
||||
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Generate PKCE codes
|
||||
codeVerifier, codeChallenge, err := generateKiroPKCE()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate PKCE: %v", err)
|
||||
setOAuthStatus(state, "Failed to generate PKCE")
|
||||
return
|
||||
}
|
||||
|
||||
// Build login URL
|
||||
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||
"https://prod.us-east-1.auth.desktop.kiro.dev",
|
||||
provider,
|
||||
url.QueryEscape(kiroauth.KiroRedirectURI),
|
||||
codeChallenge,
|
||||
state,
|
||||
)
|
||||
|
||||
// Store auth URL for frontend
|
||||
// Using "|" as separator because URLs contain ":"
|
||||
setOAuthStatus(state, "auth_url|"+authURL)
|
||||
|
||||
// Wait for callback file
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
setOAuthStatus(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
var m map[string]string
|
||||
_ = json.Unmarshal(data, &m)
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
setOAuthStatus(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
log.Errorf("State mismatch")
|
||||
setOAuthStatus(state, "State mismatch")
|
||||
return
|
||||
}
|
||||
code := m["code"]
|
||||
if code == "" {
|
||||
log.Error("No authorization code received")
|
||||
setOAuthStatus(state, "No authorization code received")
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
tokenReq := &kiroauth.CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: kiroauth.KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
|
||||
if errToken != nil {
|
||||
log.Errorf("Failed to exchange code for tokens: %v", errToken)
|
||||
setOAuthStatus(state, "Failed to exchange code for tokens")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the token
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||
if idPart == "" {
|
||||
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenResp.AccessToken,
|
||||
"refresh_token": tokenResp.RefreshToken,
|
||||
"profile_arn": tokenResp.ProfileArn,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"auth_method": "social",
|
||||
"provider": provider,
|
||||
"email": email,
|
||||
"last_refresh": now.Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if email != "" {
|
||||
fmt.Printf("Authenticated as: %s\n", email)
|
||||
}
|
||||
deleteOAuthStatus(state)
|
||||
return
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
setOAuthStatus(state, "")
|
||||
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"})
|
||||
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
|
||||
}
|
||||
}
|
||||
|
||||
// generateKiroPKCE creates the PKCE verifier/challenge pair used by the Kiro
// OAuth login.
//
// The verifier is 32 bytes read from crypto/rand, encoded as unpadded URL-safe
// base64. The challenge is the unpadded URL-safe base64 encoding of the SHA-256
// digest of the verifier string. A non-nil error is returned only when the
// random source cannot be read, in which case both strings are empty.
func generateKiroPKCE() (verifier, challenge string, err error) {
	var seed [32]byte
	if _, readErr := io.ReadFull(rand.Reader, seed[:]); readErr != nil {
		return "", "", fmt.Errorf("failed to generate random bytes: %w", readErr)
	}

	enc := base64.RawURLEncoding
	verifier = enc.EncodeToString(seed[:])

	// The digest is taken over the encoded verifier text, not the raw bytes.
	digest := sha256.Sum256([]byte(verifier))
	challenge = enc.EncodeToString(digest[:])

	return verifier, challenge, nil
}
|
||||
|
||||
@@ -137,7 +137,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||
// Pass auth middleware to require valid API key for all management routes.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
@@ -187,9 +188,6 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
if !newSettings.RestrictManagementToLocalhost {
|
||||
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||
}
|
||||
}
|
||||
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
|
||||
@@ -64,7 +64,7 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
|
||||
@@ -3,8 +3,11 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
@@ -41,6 +44,11 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
@@ -50,7 +58,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
// Inject API key from secret source (only uses upstream-api-key from config)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
@@ -62,7 +70,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Modify incoming responses to handle gzip without Content-Encoding
|
||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Only process successful responses
|
||||
// Log upstream error responses for diagnostics (502, 503, etc.)
|
||||
// These are NOT proxy connection errors - the upstream responded with an error status
|
||||
if resp.StatusCode >= 500 {
|
||||
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
} else if resp.StatusCode >= 400 {
|
||||
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
}
|
||||
|
||||
// Only process successful responses for gzip decompression
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil
|
||||
}
|
||||
@@ -146,9 +162,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error handler for proxy failures
|
||||
// Error handler for proxy failures with detailed error classification for diagnostics
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
// Classify the error type for better diagnostics
|
||||
var errType string
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
errType = "timeout"
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
errType = "canceled"
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
errType = "dial_timeout"
|
||||
} else if _, ok := err.(net.Error); ok {
|
||||
errType = "network_error"
|
||||
} else {
|
||||
errType = "connection_error"
|
||||
}
|
||||
|
||||
// Don't log as error for context canceled - it's usually client closing connection
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||
} else {
|
||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||
|
||||
@@ -29,17 +29,79 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
|
||||
}
|
||||
}
|
||||
|
||||
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
||||
|
||||
// looksLikeSSEChunk reports whether data contains any marker that suggests a
// server-sent-events payload.
//
// It is a fallback for upstreams that omit or misreport Content-Type, which
// would otherwise cause an SSE stream to be buffered. The check is a plain
// substring scan, kept intentionally simple and cheap: SSE field prefixes,
// a handful of streaming event names, and the blank-line event separator.
func looksLikeSSEChunk(data []byte) bool {
	markers := [...]string{
		"data:",
		"event:",
		"message_start",
		"message_delta",
		"content_block_start",
		"content_block_delta",
		"content_block_stop",
		"\n\n",
	}
	for _, marker := range markers {
		if bytes.Contains(data, []byte(marker)) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||
if rw.isStreaming {
|
||||
return nil
|
||||
}
|
||||
rw.isStreaming = true
|
||||
|
||||
// Flush any previously buffered data to avoid reordering or data loss.
|
||||
if rw.body != nil && rw.body.Len() > 0 {
|
||||
buf := rw.body.Bytes()
|
||||
// Copy before Reset() to keep bytes stable.
|
||||
toFlush := make([]byte, len(buf))
|
||||
copy(toFlush, buf)
|
||||
rw.body.Reset()
|
||||
|
||||
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write intercepts response writes and buffers them for model name replacement
|
||||
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
// Detect streaming on first write
|
||||
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||
// Detect streaming on first write (header-based)
|
||||
if !rw.isStreaming && rw.body.Len() == 0 {
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||
strings.Contains(contentType, "stream")
|
||||
}
|
||||
|
||||
if !rw.isStreaming {
|
||||
// Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
|
||||
if looksLikeSSEChunk(data) {
|
||||
if err := rw.enableStreaming("sse heuristic"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
|
||||
// Safety cap: avoid unbounded buffering on large responses.
|
||||
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
|
||||
if err := rw.enableStreaming("buffer limit"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
@@ -98,7 +98,8 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||
// The auth middleware validates Authorization header against configured API keys.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
@@ -107,8 +108,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
}
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
@@ -154,6 +156,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if auth != nil {
|
||||
rootMiddleware = append(rootMiddleware, auth)
|
||||
}
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
@@ -349,6 +349,12 @@ func (s *Server) setupRoutes() {
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// Event logging endpoint - handles Claude Code telemetry requests
|
||||
// Returns 200 OK to prevent 404 errors in logs
|
||||
s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
||||
|
||||
// OAuth callback endpoints (reuse main server port)
|
||||
@@ -415,6 +421,18 @@ func (s *Server) setupRoutes() {
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/kiro/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||
}
|
||||
|
||||
@@ -568,6 +586,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
@@ -580,6 +599,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -36,3 +39,61 @@ func SanitizeIFlowFileName(raw string) string {
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// ExtractBXAuth returns the value of the BXAuth cookie found in a raw cookie
// string such as "a=1; BXAuth=token; b=2".
//
// Pairs are separated by ';' and whitespace around each pair is ignored. The
// first pair that starts with "BXAuth=" wins and its value is returned as-is,
// which means an empty "BXAuth=" pair yields "" even if a later pair carries a
// value. An empty string is also returned when no BXAuth pair is present.
func ExtractBXAuth(cookie string) string {
	const prefix = "BXAuth="

	// Walk the pairs with strings.Cut rather than strings.Split: no
	// intermediate slice is allocated and scanning stops at the first match.
	for rest, more := cookie, true; more; {
		var part string
		part, rest, more = strings.Cut(rest, ";")
		part = strings.TrimSpace(part)
		if strings.HasPrefix(part, prefix) {
			return part[len(prefix):]
		}
	}
	return ""
}
|
||||
|
||||
// CheckDuplicateBXAuth checks if the given BXAuth value already exists in any iflow auth file.
|
||||
// Returns the path of the existing file if found, empty string otherwise.
|
||||
func CheckDuplicateBXAuth(authDir, bxAuth string) (string, error) {
|
||||
if bxAuth == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read auth dir failed: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "iflow-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var tokenData struct {
|
||||
Cookie string `json:"cookie"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
existingBXAuth := ExtractBXAuth(tokenData.Cookie)
|
||||
if existingBXAuth != "" && existingBXAuth == bxAuth {
|
||||
return filePath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -506,11 +506,18 @@ func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenS
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only save the BXAuth field from the cookie
|
||||
bxAuth := ExtractBXAuth(data.Cookie)
|
||||
cookieToSave := ""
|
||||
if bxAuth != "" {
|
||||
cookieToSave = "BXAuth=" + bxAuth + ";"
|
||||
}
|
||||
|
||||
return &IFlowTokenStorage{
|
||||
APIKey: data.APIKey,
|
||||
Email: data.Email,
|
||||
Expire: data.Expire,
|
||||
Cookie: data.Cookie,
|
||||
Cookie: cookieToSave,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Type: "iflow",
|
||||
}
|
||||
|
||||
@@ -126,8 +126,8 @@ func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, s
|
||||
)
|
||||
}
|
||||
|
||||
// createToken exchanges the authorization code for tokens.
|
||||
func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||
// CreateToken exchanges the authorization code for tokens.
|
||||
func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||
@@ -326,7 +326,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
RedirectURI: KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := c.createToken(ctx, tokenReq)
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -37,6 +39,16 @@ func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicate BXAuth before authentication
|
||||
bxAuth := iflow.ExtractBXAuth(cookie)
|
||||
if existingFile, err := iflow.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil {
|
||||
fmt.Printf("Failed to check duplicate: %v\n", err)
|
||||
return
|
||||
} else if existingFile != "" {
|
||||
fmt.Printf("Duplicate BXAuth found, authentication already exists: %s\n", filepath.Base(existingFile))
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate with cookie
|
||||
auth := iflow.NewIFlowAuth(cfg)
|
||||
ctx := context.Background()
|
||||
@@ -82,5 +94,5 @@ func promptForCookie(promptFn func(string) (string, error)) (string, error) {
|
||||
// getAuthFilePath returns the auth file path for the given provider and email
|
||||
func getAuthFilePath(cfg *config.Config, provider, email string) string {
|
||||
fileName := iflow.SanitizeIFlowFileName(email)
|
||||
return fmt.Sprintf("%s/%s-%s.json", cfg.AuthDir, provider, fileName)
|
||||
return fmt.Sprintf("%s/%s-%s-%d.json", cfg.AuthDir, provider, fileName, time.Now().Unix())
|
||||
}
|
||||
|
||||
@@ -64,6 +64,10 @@ type Config struct {
|
||||
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
|
||||
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
|
||||
|
||||
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
|
||||
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
|
||||
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
|
||||
|
||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||
|
||||
@@ -147,7 +151,7 @@ type AmpCode struct {
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||
RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"`
|
||||
|
||||
// ModelMappings defines model name mappings for Amp CLI requests.
|
||||
@@ -278,6 +282,10 @@ type KiroKey struct {
|
||||
// AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat".
|
||||
// Leave empty to let API use defaults. Different values may inject different system prompts.
|
||||
AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"`
|
||||
|
||||
// PreferredEndpoint sets the preferred Kiro API endpoint/quota.
|
||||
// Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota).
|
||||
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
@@ -360,7 +368,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.LoggingToFile = false
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if optional {
|
||||
@@ -504,6 +512,7 @@ func (cfg *Config) SanitizeKiroKeys() {
|
||||
entry.ProfileArn = strings.TrimSpace(entry.ProfileArn)
|
||||
entry.Region = strings.TrimSpace(entry.Region)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -580,7 +580,7 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -648,10 +648,11 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
@@ -884,8 +885,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||
func GetKiroModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
// --- Base Models ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4.5",
|
||||
ID: "kiro-claude-opus-4-5",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
@@ -894,9 +896,10 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4.5",
|
||||
ID: "kiro-claude-sonnet-4-5",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
@@ -905,6 +908,7 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4",
|
||||
@@ -916,9 +920,10 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-haiku-4.5",
|
||||
ID: "kiro-claude-haiku-4-5",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
@@ -927,22 +932,11 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
},
|
||||
// --- Chat Variant (No tool calling, for pure conversation) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4.5-chat",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Opus 4.5 (Chat)",
|
||||
Description: "Claude Opus 4.5 for chat only (no tool calling)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4.5-agentic",
|
||||
ID: "kiro-claude-opus-4-5-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
@@ -951,9 +945,10 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4.5-agentic",
|
||||
ID: "kiro-claude-sonnet-4-5-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
@@ -962,6 +957,31 @@ func GetKiroModels() []*ModelInfo {
|
||||
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4 (Agentic)",
|
||||
Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-haiku-4-5-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Haiku 4.5 (Agentic)",
|
||||
Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +90,9 @@ type ModelRegistry struct {
|
||||
models map[string]*ModelRegistration
|
||||
// clientModels maps client ID to the models it provides
|
||||
clientModels map[string][]string
|
||||
// clientModelInfos maps client ID to a map of model ID -> ModelInfo
|
||||
// This preserves the original model info provided by each client
|
||||
clientModelInfos map[string]map[string]*ModelInfo
|
||||
// clientProviders maps client ID to its provider identifier
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
@@ -104,10 +107,11 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
@@ -144,6 +148,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
// No models supplied; unregister existing client state if present.
|
||||
r.unregisterClientInternal(clientID)
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
@@ -152,7 +157,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
now := time.Now()
|
||||
|
||||
oldModels, hadExisting := r.clientModels[clientID]
|
||||
oldProvider, _ := r.clientProviders[clientID]
|
||||
oldProvider := r.clientProviders[clientID]
|
||||
providerChanged := oldProvider != provider
|
||||
if !hadExisting {
|
||||
// Pure addition path.
|
||||
@@ -161,6 +166,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
r.addModelRegistration(modelID, provider, model, now)
|
||||
}
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
// Store client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -287,6 +298,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
if len(rawModelIDs) > 0 {
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
}
|
||||
// Update client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -436,6 +453,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
}
|
||||
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
if hasProvider {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
@@ -748,7 +766,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
}
|
||||
return result
|
||||
|
||||
case "claude":
|
||||
case "claude", "kiro", "antigravity":
|
||||
// Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client
|
||||
result := map[string]any{
|
||||
"id": model.ID,
|
||||
"object": "model",
|
||||
@@ -763,6 +782,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
if model.DisplayName != "" {
|
||||
result["display_name"] = model.DisplayName
|
||||
}
|
||||
// Add thinking support for Claude Code client
|
||||
// Claude Code checks for "thinking" field (simple boolean) to enable tab toggle
|
||||
// Also add "extended_thinking" for detailed budget info
|
||||
if model.Thinking != nil {
|
||||
result["thinking"] = true
|
||||
result["extended_thinking"] = map[string]any{
|
||||
"supported": true,
|
||||
"min": model.Thinking.Min,
|
||||
"max": model.Thinking.Max,
|
||||
"zero_allowed": model.Thinking.ZeroAllowed,
|
||||
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
case "gemini":
|
||||
@@ -871,3 +903,44 @@ func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, erro
|
||||
|
||||
return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
|
||||
}
|
||||
|
||||
// GetModelsForClient returns the models registered for a specific client.
|
||||
// Parameters:
|
||||
// - clientID: The client identifier (typically auth file name or auth ID)
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: List of models registered for this client, nil if client not found
|
||||
func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
modelIDs, exists := r.clientModels[clientID]
|
||||
if !exists || len(modelIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to use client-specific model infos first
|
||||
clientInfos := r.clientModelInfos[clientID]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
result := make([]*ModelInfo, 0, len(modelIDs))
|
||||
for _, modelID := range modelIDs {
|
||||
if _, dup := seen[modelID]; dup {
|
||||
continue
|
||||
}
|
||||
seen[modelID] = struct{}{}
|
||||
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -322,7 +322,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
@@ -384,8 +384,16 @@ func ensureColonSpacedJSON(payload []byte) []byte {
|
||||
|
||||
for i := 0; i < len(indented); i++ {
|
||||
ch := indented[i]
|
||||
if ch == '"' && (i == 0 || indented[i-1] != '\\') {
|
||||
inString = !inString
|
||||
if ch == '"' {
|
||||
// A quote is escaped only when preceded by an odd number of consecutive backslashes.
|
||||
// For example: "\\\"" keeps the quote inside the string, but "\\\\" closes the string.
|
||||
backslashes := 0
|
||||
for j := i - 1; j >= 0 && indented[j] == '\\'; j-- {
|
||||
backslashes++
|
||||
}
|
||||
if backslashes%2 == 0 {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
|
||||
if !inString {
|
||||
|
||||
@@ -54,9 +54,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -152,9 +152,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -254,7 +254,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
modelForCounting := req.Model
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -784,20 +786,45 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
// Try to parse the retryDelay from the error response
|
||||
// Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo"
|
||||
details := gjson.GetBytes(errorBody, "error.details")
|
||||
if !details.Exists() || !details.IsArray() {
|
||||
return nil, fmt.Errorf("no error.details found")
|
||||
if details.Exists() && details.IsArray() {
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try ErrorInfo.metadata.quotaResetDelay (e.g., "373.801628ms")
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||
quotaResetDelay := detail.Get("metadata.quotaResetDelay").String()
|
||||
if quotaResetDelay != "" {
|
||||
duration, err := time.ParseDuration(quotaResetDelay)
|
||||
if err == nil {
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
// Fallback: parse from error.message "Your quota will reset after Xs."
|
||||
message := gjson.GetBytes(errorBody, "error.message").String()
|
||||
if message != "" {
|
||||
re := regexp.MustCompile(`after\s+(\d+)s\.?`)
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
@@ -178,7 +178,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
@@ -290,7 +290,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = applyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
@@ -57,13 +57,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -148,13 +148,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
@@ -219,7 +219,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -54,17 +54,19 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
|
||||
@@ -148,17 +150,19 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
|
||||
@@ -323,6 +327,27 @@ func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxy
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) allowCompatReasoningEffort(model string, auth *cliproxyauth.Auth) bool {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" || e == nil || e.cfg == nil {
|
||||
return false
|
||||
}
|
||||
compat := e.resolveCompatConfig(auth)
|
||||
if compat == nil || len(compat.Models) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range compat.Models {
|
||||
entry := compat.Models[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Alias), trimmed) {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Name), trimmed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// applyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
@@ -45,22 +45,44 @@ func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model str
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// ApplyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// Metadata values take precedence over any existing field when the model supports thinking, intentionally
|
||||
// overwriting caller-provided values to honor suffix/default metadata priority.
|
||||
func applyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string) []byte {
|
||||
func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
|
||||
if len(metadata) == 0 {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
}
|
||||
if field == "" {
|
||||
return payload
|
||||
}
|
||||
baseModel := util.ResolveOriginalModel(model, metadata)
|
||||
if baseModel == "" {
|
||||
baseModel = model
|
||||
}
|
||||
if !util.ModelSupportsThinking(baseModel) && !allowCompat {
|
||||
return payload
|
||||
}
|
||||
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if *budget == 0 && effort == "none" && util.ModelUsesThinkingLevels(baseModel) {
|
||||
if _, supported := util.NormalizeReasoningEffortLevel(baseModel, effort); !supported {
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
}
|
||||
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
@@ -216,34 +238,43 @@ func matchModelPattern(pattern, model string) bool {
|
||||
return pi == len(pattern)
|
||||
}
|
||||
|
||||
// normalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// NormalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// based on model capabilities. For models without thinking support, it strips
|
||||
// reasoning fields. For models with level-based thinking, it validates and
|
||||
// normalizes the reasoning effort level.
|
||||
func normalizeThinkingConfig(payload []byte, model string) []byte {
|
||||
// normalizes the reasoning effort level. For models with numeric budget thinking,
|
||||
// it strips the effort string fields.
|
||||
func NormalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return payload
|
||||
}
|
||||
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return stripThinkingFields(payload)
|
||||
if allowCompat {
|
||||
return payload
|
||||
}
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
|
||||
if util.ModelUsesThinkingLevels(model) {
|
||||
return normalizeReasoningEffortLevel(payload, model)
|
||||
return NormalizeReasoningEffortLevel(payload, model)
|
||||
}
|
||||
|
||||
return payload
|
||||
// Model supports thinking but uses numeric budgets, not levels.
|
||||
// Strip effort string fields since they are not applicable.
|
||||
return StripThinkingFields(payload, true)
|
||||
}
|
||||
|
||||
// stripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking.
|
||||
func stripThinkingFields(payload []byte) []byte {
|
||||
// StripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking. If effortOnly is true, only removes
|
||||
// effort string fields (for models using numeric budgets).
|
||||
func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
fieldsToRemove := []string{
|
||||
"reasoning",
|
||||
"reasoning_effort",
|
||||
"reasoning.effort",
|
||||
}
|
||||
if !effortOnly {
|
||||
fieldsToRemove = append([]string{"reasoning"}, fieldsToRemove...)
|
||||
}
|
||||
out := payload
|
||||
for _, field := range fieldsToRemove {
|
||||
if gjson.GetBytes(out, field).Exists() {
|
||||
@@ -253,9 +284,9 @@ func stripThinkingFields(payload []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// NormalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// or reasoning.effort field for level-based thinking models.
|
||||
func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
func NormalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
out := payload
|
||||
|
||||
if effort := gjson.GetBytes(out, "reasoning_effort"); effort.Exists() {
|
||||
@@ -273,10 +304,10 @@ func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// validateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// ValidateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// Returns a statusErr with 400 when an unsupported level is supplied to avoid silently
|
||||
// downgrading requests.
|
||||
func validateThinkingConfig(payload []byte, model string) error {
|
||||
func ValidateThinkingConfig(payload []byte, model string) error {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,13 +51,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -131,13 +131,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
|
||||
@@ -2,43 +2,107 @@ package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||
var tokenizerCache sync.Map
|
||||
|
||||
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||
type TokenizerWrapper struct {
|
||||
Codec tokenizer.Codec
|
||||
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||
}
|
||||
|
||||
// Count returns the token count with adjustment factor applied
|
||||
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||
count, err := tw.Codec.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// getTokenizer returns a cached tokenizer for the given model.
|
||||
// This improves performance by avoiding repeated tokenizer creation.
|
||||
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||
// Check cache first
|
||||
if cached, ok := tokenizerCache.Load(model); ok {
|
||||
return cached.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// Cache miss, create new tokenizer
|
||||
wrapper, err := tokenizerForModel(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (use LoadOrStore to handle race conditions)
|
||||
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||
return actual.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
|
||||
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||
// because tiktoken may underestimate Claude's actual token count
|
||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||
}
|
||||
|
||||
var enc tokenizer.Codec
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||
}
|
||||
|
||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
@@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||
// This handles Claude's message format with system, messages, and tools.
|
||||
// Image tokens are estimated based on image dimensions when available.
|
||||
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
|
||||
// Collect system prompt (can be string or array of content blocks)
|
||||
collectClaudeSystem(root.Get("system"), &segments)
|
||||
|
||||
// Collect messages
|
||||
collectClaudeMessages(root.Get("messages"), &segments)
|
||||
|
||||
// Collect tools
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// imageTokenPattern recognises image placeholders of the form
// [IMAGE:xxx tokens] and captures the decimal estimate xxx.
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)

// extractImageTokens sums the token estimates of every [IMAGE:xxx tokens]
// placeholder found in text. Placeholders whose number cannot be parsed as an
// int are ignored. It returns 0 when text contains no placeholder.
func extractImageTokens(text string) int {
	sum := 0
	for _, groups := range imageTokenPattern.FindAllStringSubmatch(text, -1) {
		if len(groups) < 2 {
			continue
		}
		n, err := strconv.Atoi(groups[1])
		if err != nil {
			continue
		}
		sum += n
	}
	return sum
}
|
||||
|
||||
// estimateImageTokens returns an estimated token cost for an image of the
// given pixel dimensions, using the approximation
// tokens ≈ (width * height) / 750.
//
// The result is clamped to the range [85, 1590]; at this ratio the upper bound
// corresponds to roughly a 1092x1092 image. When either dimension is not a
// positive number (zero, negative or NaN) a default of 1000 tokens is returned
// as a medium-sized-image guess.
func estimateImageTokens(width, height float64) int {
	const (
		defaultTokens  = 1000
		minTokens      = 85
		maxTokens      = 1590
		pixelsPerToken = 750
	)

	// Written as a negated conjunction so that NaN, which fails every
	// comparison, is treated as an invalid dimension too.
	if !(width > 0 && height > 0) {
		return defaultTokens
	}

	estimate := width * height / pixelsPerToken

	// Clamp while still in floating point: converting a float that is out of
	// int range (or +Inf) to int is implementation-defined in Go and could
	// otherwise turn a huge image into the minimum.
	switch {
	case estimate < minTokens:
		return minTokens
	case estimate > maxTokens:
		return maxTokens
	}
	return int(estimate)
}
|
||||
|
||||
// collectClaudeSystem extracts text from Claude's system field.
|
||||
// System can be a string or an array of content blocks.
|
||||
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||
if !system.Exists() {
|
||||
return
|
||||
}
|
||||
if system.Type == gjson.String {
|
||||
addIfNotEmpty(segments, system.String())
|
||||
return
|
||||
}
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "text" || blockType == "" {
|
||||
addIfNotEmpty(segments, block.Get("text").String())
|
||||
}
|
||||
// Also handle plain string blocks
|
||||
if block.Type == gjson.String {
|
||||
addIfNotEmpty(segments, block.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeMessages extracts text from Claude's messages array.
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// collectClaudeContent extracts text from Claude's content field.
|
||||
// Content can be a string or an array of content blocks.
|
||||
// For images, estimates token count based on dimensions when available.
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
// Estimate image tokens based on dimensions if available
|
||||
source := part.Get("source")
|
||||
if source.Exists() {
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if width > 0 && height > 0 {
|
||||
tokens := estimateImageTokens(width, height)
|
||||
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||
} else {
|
||||
// No dimensions available, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
} else {
|
||||
// No source info, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
// For unknown types, try to extract any text content
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeTools extracts text from Claude's tools array.
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
|
||||
@@ -122,6 +122,38 @@ type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
// parseFunctionResponse attempts to unmarshal a function response part.
|
||||
// Falls back to gjson extraction if standard json.Unmarshal fails.
|
||||
func parseFunctionResponse(response gjson.Result) map[string]interface{} {
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if err == nil {
|
||||
return responseMap
|
||||
}
|
||||
|
||||
log.Debugf("unmarshal function response failed, using fallback: %v", err)
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := map[string]interface{}{
|
||||
"name": funcResp.Get("name").String(),
|
||||
"response": map[string]interface{}{
|
||||
"result": funcResp.Get("response").String(),
|
||||
},
|
||||
}
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr["id"] = id
|
||||
}
|
||||
return map[string]interface{}{"functionResponse": fr}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"functionResponse": map[string]interface{}{
|
||||
"name": "unknown",
|
||||
"response": map[string]interface{}{"result": response.String()},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||
@@ -180,13 +212,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
responseParts = append(responseParts, parseFunctionResponse(response))
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
@@ -265,13 +291,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
responseParts = append(responseParts, parseFunctionResponse(response))
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
|
||||
@@ -39,31 +39,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -114,14 +114,16 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
}
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() {
|
||||
if includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", thinkingBudget.Int())
|
||||
}
|
||||
}
|
||||
// Only apply for models that support thinking and use numeric budgets, not discrete levels.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
// Check for thinkingBudget first - if present, enable thinking with budget
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() && thinkingBudget.Int() > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
normalizedBudget := util.NormalizeThinkingBudget(modelName, int(thinkingBudget.Int()))
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", normalizedBudget)
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
// Fallback to include_thoughts if no budget specified
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -65,18 +66,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning_effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -52,20 +53,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning.effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 4096)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning.effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -214,7 +215,22 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Add additional configuration parameters for the Codex API.
|
||||
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||
template, _ = sjson.Set(template, "reasoning.effort", "low")
|
||||
|
||||
// Convert thinking.budget_tokens to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if thinking := rootResult.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinking.Get("type").String() == "enabled" {
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.Set(template, "stream", true)
|
||||
template, _ = sjson.Set(template, "store", false)
|
||||
|
||||
@@ -245,7 +245,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Fixed flags aligning with Codex expectations
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "low")
|
||||
|
||||
// Convert thinkingBudget to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out, _ = sjson.Set(out, "reasoning.effort", reasoningEffort)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "stream", true)
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
|
||||
@@ -60,7 +60,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", v.Value())
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "low")
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
}
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
|
||||
@@ -39,31 +39,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -154,7 +154,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
|
||||
@@ -37,33 +37,17 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Reasoning effort -> thinkingBudget/include_thoughts
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
// Only convert for models that use numeric budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -389,36 +389,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
// OpenAI official reasoning fields take precedence
|
||||
// Only convert for models that use numeric budgets (not discrete levels).
|
||||
hasOfficialThinking := root.Get("reasoning.effort").Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
reasoningEffort := root.Get("reasoning.effort")
|
||||
switch reasoningEffort.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", false)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
out = string(util.ApplyReasoningEffortToGemini([]byte(out), reasoningEffort.String()))
|
||||
}
|
||||
|
||||
// Cherry Studio extension (applies only when official fields are missing)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -35,5 +35,5 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
|
||||
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai"
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package claude provides translation between Kiro and Claude formats.
|
||||
package claude
|
||||
|
||||
import (
|
||||
@@ -12,8 +13,8 @@ func init() {
|
||||
Kiro,
|
||||
ConvertClaudeRequestToKiro,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertKiroResponseToClaude,
|
||||
NonStream: ConvertKiroResponseToClaudeNonStream,
|
||||
Stream: ConvertKiroStreamToClaude,
|
||||
NonStream: ConvertKiroNonStreamToClaude,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,27 +1,21 @@
|
||||
// Package claude provides translation between Kiro and Claude formats.
|
||||
// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
|
||||
// translations are pass-through.
|
||||
// translations are pass-through for streaming, but responses need proper formatting.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
)
|
||||
|
||||
// ConvertClaudeRequestToKiro converts Claude request to Kiro format.
|
||||
// Since Kiro uses Claude format internally, this is mostly a pass-through.
|
||||
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
return bytes.Clone(inputRawJSON)
|
||||
}
|
||||
|
||||
// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format.
|
||||
// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format.
|
||||
// Kiro executor already generates complete SSE format with "event:" prefix,
|
||||
// so this is a simple pass-through.
|
||||
func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
return []string{string(rawResponse)}
|
||||
}
|
||||
|
||||
// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format.
|
||||
func ConvertKiroResponseToClaudeNonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
// ConvertKiroNonStreamToClaude converts Kiro non-streaming response to Claude format.
|
||||
// The response is already in Claude format, so this is a pass-through.
|
||||
func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
return string(rawResponse)
|
||||
}
|
||||
|
||||
774
internal/translator/kiro/claude/kiro_claude_request.go
Normal file
774
internal/translator/kiro/claude/kiro_claude_request.go
Normal file
@@ -0,0 +1,774 @@
|
||||
// Package claude provides request translation functionality for Claude API to Kiro format.
|
||||
// It handles parsing and transforming Claude API requests into the Kiro/Amazon Q API format,
|
||||
// extracting model information, system instructions, message contents, and tool declarations.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
|
||||
// Kiro API request structs - field order determines JSON key order
|
||||
|
||||
// KiroPayload is the top-level request structure for the Kiro API.
// Fields are declared in the order the JSON keys should be emitted.
|
||||
type KiroPayload struct {
|
||||
// ConversationState carries the current message and history; always emitted.
ConversationState KiroConversationState `json:"conversationState"`
|
||||
// ProfileArn is left out of the JSON when it is the empty string.
ProfileArn string `json:"profileArn,omitempty"`
|
||||
// InferenceConfig is left out of the JSON when nil.
InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
}
|
||||
|
||||
// KiroInferenceConfig contains inference parameters for the Kiro API.
// Every field is tagged omitempty, so zero values are not serialized.
// NOTE(review): this means an explicitly requested temperature or topP of 0
// is dropped from the request — confirm that is the intended behavior.
|
||||
type KiroInferenceConfig struct {
|
||||
// MaxTokens is the output token limit; omitted when 0.
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
// Temperature is the sampling temperature; omitted when 0.
Temperature float64 `json:"temperature,omitempty"`
|
||||
// TopP is the nucleus sampling value; omitted when 0.
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// KiroConversationState holds the conversation context: the trigger type,
// a conversation identifier, the message being sent now, and prior history.
|
||||
type KiroConversationState struct {
|
||||
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field
|
||||
// ConversationID identifies the conversation; always emitted.
ConversationID string `json:"conversationId"`
|
||||
// CurrentMessage is the user message being submitted in this request.
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
|
||||
// History holds earlier turns; omitted from the JSON when empty.
History []KiroHistoryMessage `json:"history,omitempty"`
|
||||
}
|
||||
|
||||
// KiroCurrentMessage wraps the current user message under the
// "userInputMessage" JSON key.
|
||||
type KiroCurrentMessage struct {
|
||||
// UserInputMessage is held by value, so it is always emitted.
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||||
}
|
||||
|
||||
// KiroHistoryMessage represents a message in the conversation history.
// Both fields are optional pointers that are omitted from the JSON when nil.
// NOTE(review): presumably exactly one of the two is set per entry — confirm
// against the code that builds the history.
|
||||
type KiroHistoryMessage struct {
|
||||
// UserInputMessage is set for a user turn.
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||||
// AssistantResponseMessage is set for an assistant turn.
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||||
}
|
||||
|
||||
// KiroImage represents an image in Kiro API format: a format label plus the
// image data source.
|
||||
type KiroImage struct {
|
||||
// Format names the image format.
// NOTE(review): accepted values (e.g. "png", "jpeg") are not visible here — confirm.
Format string `json:"format"`
|
||||
// Source holds the encoded image data.
Source KiroImageSource `json:"source"`
|
||||
}
|
||||
|
||||
// KiroImageSource contains the image data, serialized under the "bytes" key.
|
||||
type KiroImageSource struct {
|
||||
Bytes string `json:"bytes"` // base64 encoded image data
|
||||
}
|
||||
|
||||
// KiroUserInputMessage represents a user message, including the model and
// origin it is sent with and any attached images or tool context.
|
||||
type KiroUserInputMessage struct {
|
||||
// Content is the message text; always emitted, even when empty.
Content string `json:"content"`
|
||||
// ModelID is the model identifier sent with the message.
ModelID string `json:"modelId"`
|
||||
// Origin is the request origin value sent with the message.
Origin string `json:"origin"`
|
||||
// Images are attached images; omitted from the JSON when empty.
Images []KiroImage `json:"images,omitempty"`
|
||||
// UserInputMessageContext carries tools and tool results; omitted when nil.
UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||||
}
|
||||
|
||||
// KiroUserInputMessageContext contains tool-related context attached to a
// user message. Both lists are omitted from the JSON when empty.
|
||||
type KiroUserInputMessageContext struct {
|
||||
// ToolResults are results of earlier tool invocations.
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||||
// Tools are the tool specifications offered to the model.
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// KiroToolResult represents a tool execution result that is sent back to
// the model, keyed by the tool use it answers.
|
||||
type KiroToolResult struct {
|
||||
// Content holds the textual parts of the result.
Content []KiroTextContent `json:"content"`
|
||||
// Status is the outcome label of the tool run.
// NOTE(review): the set of valid values is not visible here — confirm.
Status string `json:"status"`
|
||||
// ToolUseID is the identifier of the tool invocation this result belongs to.
ToolUseID string `json:"toolUseId"`
|
||||
}
|
||||
|
||||
// KiroTextContent represents a single piece of text content, serialized
// under the "text" key.
|
||||
type KiroTextContent struct {
|
||||
// Text is the content string; always emitted, even when empty.
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// KiroToolWrapper wraps a tool specification under the "toolSpecification"
// JSON key.
|
||||
type KiroToolWrapper struct {
|
||||
// ToolSpecification is the wrapped tool definition.
ToolSpecification KiroToolSpecification `json:"toolSpecification"`
|
||||
}
|
||||
|
||||
// KiroToolSpecification defines a tool's schema: its name, a description,
// and the JSON schema of its input.
|
||||
type KiroToolSpecification struct {
|
||||
// Name is the tool name.
Name string `json:"name"`
|
||||
// Description is the tool description; always emitted, even when empty.
Description string `json:"description"`
|
||||
// InputSchema wraps the JSON schema describing the tool's input.
InputSchema KiroInputSchema `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// KiroInputSchema wraps the JSON schema for tool input under the "json" key.
type KiroInputSchema struct {
	// JSON holds the schema as an arbitrary JSON-marshalable value
	// (typically a decoded JSON object). A nil value is emitted as null.
	JSON any `json:"json"`
}
|
||||
|
||||
// KiroAssistantResponseMessage represents an assistant message: its text
// and any tool invocations it made.
|
||||
type KiroAssistantResponseMessage struct {
|
||||
// Content is the assistant's text; always emitted, even when empty.
Content string `json:"content"`
|
||||
// ToolUses lists tool invocations; omitted from the JSON when empty.
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||||
}
|
||||
|
||||
// KiroToolUse represents a tool invocation by the assistant.
type KiroToolUse struct {
	// ToolUseID identifies this invocation so a later tool result can refer to it.
	ToolUseID string `json:"toolUseId"`
	// Name is the name of the invoked tool.
	Name string `json:"name"`
	// Input holds the decoded tool arguments. It has no omitempty tag, so a
	// nil map is emitted as null rather than dropped.
	Input map[string]any `json:"input"`
}
|
||||
|
||||
// ConvertClaudeRequestToKiro is the request-translation entry point for
// Claude -> Kiro. It performs no translation itself: the Claude-format body
// is returned unchanged. modelName and stream are not used.
|
||||
// The returned slice is inputRawJSON itself (no copy is made), so it shares
// the caller's backing array.
|
||||
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
// For Kiro, we pass through the Claude format since BuildKiroPayload
|
||||
// expects Claude format and does the conversion internally.
|
||||
// The actual conversion happens in the executor when building the HTTP request.
|
||||
return inputRawJSON
|
||||
}
|
||||
|
||||
// BuildKiroPayload constructs the Kiro API request payload from Claude format.
|
||||
// Supports tool calling - tools are passed via userInputMessageContext.
|
||||
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint.
|
||||
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||
func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||
const kiroMaxOutputTokens = 32000
|
||||
var maxTokens int64
|
||||
if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
|
||||
maxTokens = mt.Int()
|
||||
if maxTokens == -1 {
|
||||
maxTokens = kiroMaxOutputTokens
|
||||
log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract temperature if specified
|
||||
var temperature float64
|
||||
var hasTemperature bool
|
||||
if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
|
||||
temperature = temp.Float()
|
||||
hasTemperature = true
|
||||
}
|
||||
|
||||
// Extract top_p if specified
|
||||
var topP float64
|
||||
var hasTopP bool
|
||||
if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() {
|
||||
topP = tp.Float()
|
||||
hasTopP = true
|
||||
log.Debugf("kiro: extracted top_p: %.2f", topP)
|
||||
}
|
||||
|
||||
// Normalize origin value for Kiro API compatibility
|
||||
origin = normalizeOrigin(origin)
|
||||
log.Debugf("kiro: normalized origin value: %s", origin)
|
||||
|
||||
messages := gjson.GetBytes(claudeBody, "messages")
|
||||
|
||||
// For chat-only mode, don't include tools
|
||||
var tools gjson.Result
|
||||
if !isChatOnly {
|
||||
tools = gjson.GetBytes(claudeBody, "tools")
|
||||
}
|
||||
|
||||
// Extract system prompt
|
||||
systemPrompt := extractSystemPrompt(claudeBody)
|
||||
|
||||
// Check for thinking mode using the comprehensive IsThinkingEnabled function
|
||||
// This supports Claude API format, OpenAI reasoning_effort, and AMP/Cursor format
|
||||
thinkingEnabled := IsThinkingEnabled(claudeBody)
|
||||
_, budgetTokens := checkThinkingMode(claudeBody) // Get budget tokens from Claude format if available
|
||||
if budgetTokens <= 0 {
|
||||
// Calculate budgetTokens based on max_tokens if available
|
||||
// Use 50% of max_tokens for thinking, with min 8000 and max 24000
|
||||
if maxTokens > 0 {
|
||||
budgetTokens = maxTokens / 2
|
||||
if budgetTokens < 8000 {
|
||||
budgetTokens = 8000
|
||||
}
|
||||
if budgetTokens > 24000 {
|
||||
budgetTokens = 24000
|
||||
}
|
||||
log.Debugf("kiro: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens)
|
||||
} else {
|
||||
budgetTokens = 16000 // Default budget tokens
|
||||
}
|
||||
}
|
||||
|
||||
// Inject timestamp context
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = timestampContext + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = timestampContext
|
||||
}
|
||||
log.Debugf("kiro: injected timestamp context: %s", timestamp)
|
||||
|
||||
// Inject agentic optimization prompt for -agentic model variants
|
||||
if isAgentic {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += kirocommon.KiroAgenticSystemPrompt
|
||||
}
|
||||
|
||||
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||
// Claude tool_choice values: {"type": "auto/any/tool", "name": "..."}
|
||||
toolChoiceHint := extractClaudeToolChoiceHint(claudeBody)
|
||||
if toolChoiceHint != "" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += toolChoiceHint
|
||||
log.Debugf("kiro: injected tool_choice hint into system prompt")
|
||||
}
|
||||
|
||||
// Inject thinking hint when thinking mode is enabled
|
||||
if thinkingEnabled {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
dynamicThinkingHint := fmt.Sprintf("<thinking_mode>interleaved</thinking_mode><max_thinking_length>%d</max_thinking_length>", budgetTokens)
|
||||
systemPrompt += dynamicThinkingHint
|
||||
log.Debugf("kiro: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens)
|
||||
}
|
||||
|
||||
// Convert Claude tools to Kiro format
|
||||
kiroTools := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// Process messages and build history
|
||||
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
||||
|
||||
// Build content with system prompt
|
||||
if currentUserMsg != nil {
|
||||
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||
|
||||
// Deduplicate currentToolResults
|
||||
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||
|
||||
// Build userInputMessageContext with tools and tool results
|
||||
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
|
||||
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
Tools: kiroTools,
|
||||
ToolResults: currentToolResults,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build payload
|
||||
var currentMessage KiroCurrentMessage
|
||||
if currentUserMsg != nil {
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
|
||||
} else {
|
||||
fallbackContent := ""
|
||||
if systemPrompt != "" {
|
||||
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
|
||||
}
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
|
||||
Content: fallbackContent,
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}}
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
var inferenceConfig *KiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||
inferenceConfig = &KiroInferenceConfig{}
|
||||
if maxTokens > 0 {
|
||||
inferenceConfig.MaxTokens = int(maxTokens)
|
||||
}
|
||||
if hasTemperature {
|
||||
inferenceConfig.Temperature = temperature
|
||||
}
|
||||
if hasTopP {
|
||||
inferenceConfig.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
payload := KiroPayload{
|
||||
ConversationState: KiroConversationState{
|
||||
ChatTriggerType: "MANUAL",
|
||||
ConversationID: uuid.New().String(),
|
||||
CurrentMessage: currentMessage,
|
||||
History: history,
|
||||
},
|
||||
ProfileArn: profileArn,
|
||||
InferenceConfig: inferenceConfig,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Debugf("kiro: failed to marshal payload: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return result, thinkingEnabled
|
||||
}
|
||||
|
||||
// normalizeOrigin maps origin aliases onto the value sent to the Kiro API:
// "KIRO_CLI" and "AMAZON_Q" become "CLI"; "KIRO_AI_EDITOR" and "KIRO_IDE"
// become "AI_EDITOR". Any other value is returned unchanged.
func normalizeOrigin(origin string) string {
	switch origin {
	case "KIRO_CLI", "AMAZON_Q":
		return "CLI"
	case "KIRO_AI_EDITOR", "KIRO_IDE":
		return "AI_EDITOR"
	}
	return origin
}
|
||||
|
||||
// extractSystemPrompt extracts system prompt from Claude request
|
||||
func extractSystemPrompt(claudeBody []byte) string {
|
||||
systemField := gjson.GetBytes(claudeBody, "system")
|
||||
if systemField.IsArray() {
|
||||
var sb strings.Builder
|
||||
for _, block := range systemField.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
sb.WriteString(block.Get("text").String())
|
||||
} else if block.Type == gjson.String {
|
||||
sb.WriteString(block.String())
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
return systemField.String()
|
||||
}
|
||||
|
||||
// checkThinkingMode checks if thinking mode is enabled in the Claude request
|
||||
func checkThinkingMode(claudeBody []byte) (bool, int64) {
|
||||
thinkingEnabled := false
|
||||
var budgetTokens int64 = 16000
|
||||
|
||||
thinkingField := gjson.GetBytes(claudeBody, "thinking")
|
||||
if thinkingField.Exists() {
|
||||
thinkingType := thinkingField.Get("type").String()
|
||||
if thinkingType == "enabled" {
|
||||
thinkingEnabled = true
|
||||
if bt := thinkingField.Get("budget_tokens"); bt.Exists() {
|
||||
budgetTokens = bt.Int()
|
||||
if budgetTokens <= 0 {
|
||||
thinkingEnabled = false
|
||||
log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0")
|
||||
}
|
||||
}
|
||||
if thinkingEnabled {
|
||||
log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return thinkingEnabled, budgetTokens
|
||||
}
|
||||
|
||||
// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled.
|
||||
// This is used by the executor to determine whether to parse <thinking> tags in responses.
|
||||
// When thinking is NOT enabled in the request, <thinking> tags in responses should be
|
||||
// treated as regular text content, not as thinking blocks.
|
||||
//
|
||||
// Supports multiple formats:
|
||||
// - Claude API format: thinking.type = "enabled"
|
||||
// - OpenAI format: reasoning_effort parameter
|
||||
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
func IsThinkingEnabled(body []byte) bool {
|
||||
// Check Claude API format first (thinking.type = "enabled")
|
||||
enabled, _ := checkThinkingMode(body)
|
||||
if enabled {
|
||||
log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)")
|
||||
return true
|
||||
}
|
||||
|
||||
// Check OpenAI format: reasoning_effort parameter
|
||||
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||
reasoningEffort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if reasoningEffort.Exists() {
|
||||
effort := reasoningEffort.String()
|
||||
if effort != "" && effort != "none" {
|
||||
log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
// This is how AMP client passes thinking configuration
|
||||
bodyStr := string(body)
|
||||
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
|
||||
// Extract thinking mode value
|
||||
startTag := "<thinking_mode>"
|
||||
endTag := "</thinking_mode>"
|
||||
startIdx := strings.Index(bodyStr, startTag)
|
||||
if startIdx >= 0 {
|
||||
startIdx += len(startTag)
|
||||
endIdx := strings.Index(bodyStr[startIdx:], endTag)
|
||||
if endIdx >= 0 {
|
||||
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||
log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check OpenAI format: max_completion_tokens with reasoning (o1-style)
|
||||
// Some clients use this to indicate reasoning mode
|
||||
if gjson.GetBytes(body, "max_completion_tokens").Exists() {
|
||||
// If max_completion_tokens is set, check if model name suggests reasoning
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(model), "thinking") ||
|
||||
strings.Contains(strings.ToLower(model), "reason") {
|
||||
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
||||
return false
|
||||
}
|
||||
|
||||
// shortenToolNameIfNeeded returns name unchanged when it is at most 64 bytes
// long, and a shortened form otherwise.
//
// MCP tools often have long names like "mcp__server-name__tool-name". For
// names starting with "mcp__", the candidate is "mcp__" plus the segment
// after the last "__"; for all other names the candidate is the name itself.
// A candidate that is still longer than 64 bytes is truncated.
//
// Truncation never splits a multi-byte UTF-8 character: the cut point is
// moved back to the nearest rune boundary, so the result may be slightly
// shorter than 64 bytes but stays valid UTF-8 for valid input.
func shortenToolNameIfNeeded(name string) string {
	const limit = 64
	if len(name) <= limit {
		return name
	}

	candidate := name
	if strings.HasPrefix(name, "mcp__") {
		if idx := strings.LastIndex(name, "__"); idx > 0 {
			candidate = "mcp__" + name[idx+2:]
		}
	}
	if len(candidate) <= limit {
		return candidate
	}

	// Back off to a rune boundary. A UTF-8 sequence is at most utf8.UTFMax
	// bytes, so at most UTFMax-1 steps are needed; the bound keeps malformed
	// input from shrinking the name any further.
	cut := limit
	for cut > limit-utf8.UTFMax+1 && !utf8.RuneStart(candidate[cut]) {
		cut--
	}
	return candidate[:cut]
}
|
||||
|
||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
if !tools.IsArray() {
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
description := tool.Get("description").String()
|
||||
inputSchema := tool.Get("input_schema").Value()
|
||||
|
||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||
originalName := name
|
||||
name = shortenToolNameIfNeeded(name)
|
||||
if name != originalName {
|
||||
log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name)
|
||||
}
|
||||
|
||||
// CRITICAL FIX: Kiro API requires non-empty description
|
||||
if strings.TrimSpace(description) == "" {
|
||||
description = fmt.Sprintf("Tool: %s", name)
|
||||
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Truncate long descriptions
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
truncLen--
|
||||
}
|
||||
description = description[:truncLen] + "... (description truncated)"
|
||||
}
|
||||
|
||||
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: KiroInputSchema{JSON: inputSchema},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// processMessages processes Claude messages and builds Kiro history
|
||||
func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
|
||||
var history []KiroHistoryMessage
|
||||
var currentUserMsg *KiroUserInputMessage
|
||||
var currentToolResults []KiroToolResult
|
||||
|
||||
// Merge adjacent messages with the same role
|
||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||
for i, msg := range messagesArray {
|
||||
role := msg.Get("role").String()
|
||||
isLastMessage := i == len(messagesArray)-1
|
||||
|
||||
if role == "user" {
|
||||
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
|
||||
if isLastMessage {
|
||||
currentUserMsg = &userMsg
|
||||
currentToolResults = toolResults
|
||||
} else {
|
||||
// CRITICAL: Kiro API requires content to be non-empty for history messages too
|
||||
if strings.TrimSpace(userMsg.Content) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.Content = "Tool results provided."
|
||||
} else {
|
||||
userMsg.Content = "Continue"
|
||||
}
|
||||
}
|
||||
// For history messages, embed tool results in context
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
}
|
||||
history = append(history, KiroHistoryMessage{
|
||||
UserInputMessage: &userMsg,
|
||||
})
|
||||
}
|
||||
} else if role == "assistant" {
|
||||
assistantMsg := BuildAssistantMessageStruct(msg)
|
||||
if isLastMessage {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
// Create a "Continue" user message as currentMessage
|
||||
currentUserMsg = &KiroUserInputMessage{
|
||||
Content: "Continue",
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
} else {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
// buildFinalContent builds the final content with system prompt
|
||||
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
if systemPrompt != "" {
|
||||
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
|
||||
contentBuilder.WriteString(systemPrompt)
|
||||
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
|
||||
}
|
||||
|
||||
contentBuilder.WriteString(content)
|
||||
finalContent := contentBuilder.String()
|
||||
|
||||
// CRITICAL: Kiro API requires content to be non-empty
|
||||
if strings.TrimSpace(finalContent) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
finalContent = "Tool results provided."
|
||||
} else {
|
||||
finalContent = "Continue"
|
||||
}
|
||||
log.Debugf("kiro: content was empty, using default: %s", finalContent)
|
||||
}
|
||||
|
||||
return finalContent
|
||||
}
|
||||
|
||||
// deduplicateToolResults removes duplicate tool results
|
||||
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||
if len(toolResults) == 0 {
|
||||
return toolResults
|
||||
}
|
||||
|
||||
seenIDs := make(map[string]bool)
|
||||
unique := make([]KiroToolResult, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if !seenIDs[tr.ToolUseID] {
|
||||
seenIDs[tr.ToolUseID] = true
|
||||
unique = append(unique, tr)
|
||||
} else {
|
||||
log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
return unique
|
||||
}
|
||||
|
||||
// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint.
|
||||
// Claude tool_choice values:
|
||||
// - {"type": "auto"}: Model decides (default, no hint needed)
|
||||
// - {"type": "any"}: Must use at least one tool
|
||||
// - {"type": "tool", "name": "..."}: Must use specific tool
|
||||
func extractClaudeToolChoiceHint(claudeBody []byte) string {
|
||||
toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
toolChoiceType := toolChoice.Get("type").String()
|
||||
switch toolChoiceType {
|
||||
case "any":
|
||||
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
|
||||
case "tool":
|
||||
toolName := toolChoice.Get("name").String()
|
||||
if toolName != "" {
|
||||
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
|
||||
}
|
||||
case "auto":
|
||||
// Default behavior, no hint needed
|
||||
return ""
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildUserMessageStruct builds a user message and extracts tool results
|
||||
func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolResults []KiroToolResult
|
||||
var images []KiroImage
|
||||
|
||||
// Track seen toolUseIds to deduplicate
|
||||
seenToolUseIDs := make(map[string]bool)
|
||||
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "image":
|
||||
mediaType := part.Get("source.media_type").String()
|
||||
data := part.Get("source.data").String()
|
||||
|
||||
format := ""
|
||||
if idx := strings.LastIndex(mediaType, "/"); idx != -1 {
|
||||
format = mediaType[idx+1:]
|
||||
}
|
||||
|
||||
if format != "" && data != "" {
|
||||
images = append(images, KiroImage{
|
||||
Format: format,
|
||||
Source: KiroImageSource{
|
||||
Bytes: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
case "tool_result":
|
||||
toolUseID := part.Get("tool_use_id").String()
|
||||
|
||||
// Skip duplicate toolUseIds
|
||||
if seenToolUseIDs[toolUseID] {
|
||||
log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID)
|
||||
continue
|
||||
}
|
||||
seenToolUseIDs[toolUseID] = true
|
||||
|
||||
isError := part.Get("is_error").Bool()
|
||||
resultContent := part.Get("content")
|
||||
|
||||
var textContents []KiroTextContent
|
||||
if resultContent.IsArray() {
|
||||
for _, item := range resultContent.Array() {
|
||||
if item.Get("type").String() == "text" {
|
||||
textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
|
||||
} else if item.Type == gjson.String {
|
||||
textContents = append(textContents, KiroTextContent{Text: item.String()})
|
||||
}
|
||||
}
|
||||
} else if resultContent.Type == gjson.String {
|
||||
textContents = append(textContents, KiroTextContent{Text: resultContent.String()})
|
||||
}
|
||||
|
||||
if len(textContents) == 0 {
|
||||
textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"})
|
||||
}
|
||||
|
||||
status := "success"
|
||||
if isError {
|
||||
status = "error"
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, KiroToolResult{
|
||||
ToolUseID: toolUseID,
|
||||
Content: textContents,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
userMsg := KiroUserInputMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
|
||||
if len(images) > 0 {
|
||||
userMsg.Images = images
|
||||
}
|
||||
|
||||
return userMsg, toolResults
|
||||
}
|
||||
|
||||
// BuildAssistantMessageStruct builds an assistant message with tool uses
|
||||
func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolUses []KiroToolUse
|
||||
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "tool_use":
|
||||
toolUseID := part.Get("id").String()
|
||||
toolName := part.Get("name").String()
|
||||
toolInput := part.Get("input")
|
||||
|
||||
var inputMap map[string]interface{}
|
||||
if toolInput.IsObject() {
|
||||
inputMap = make(map[string]interface{})
|
||||
toolInput.ForEach(func(key, value gjson.Result) bool {
|
||||
inputMap[key.String()] = value.Value()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
return KiroAssistantResponseMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
184
internal/translator/kiro/claude/kiro_claude_response.go
Normal file
184
internal/translator/kiro/claude/kiro_claude_response.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Package claude provides response translation functionality for Kiro API to Claude format.
|
||||
// This package handles the conversion of Kiro API responses into Claude-compatible format,
|
||||
// including support for thinking blocks and tool use.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
)
|
||||
|
||||
// Local references to kirocommon constants for thinking block parsing
|
||||
var (
|
||||
thinkingStartTag = kirocommon.ThinkingStartTag
|
||||
thinkingEndTag = kirocommon.ThinkingEndTag
|
||||
)
|
||||
|
||||
// BuildClaudeResponse constructs a Claude-compatible response.
|
||||
// Supports tool_use blocks when tools are present in the response.
|
||||
// Supports thinking blocks - parses <thinking> tags and converts to Claude thinking content blocks.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
var contentBlocks []map[string]interface{}
|
||||
|
||||
// Extract thinking blocks and text from content
|
||||
if content != "" {
|
||||
blocks := ExtractThinkingFromContent(content)
|
||||
contentBlocks = append(contentBlocks, blocks...)
|
||||
|
||||
// Log if thinking blocks were extracted
|
||||
for _, block := range blocks {
|
||||
if block["type"] == "thinking" {
|
||||
thinkingContent := block["thinking"].(string)
|
||||
log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool_use blocks
|
||||
for _, toolUse := range toolUses {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": toolUse.Input,
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure at least one content block (Claude API requires non-empty content)
|
||||
if len(contentBlocks) == 0 {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
})
|
||||
}
|
||||
|
||||
// Use upstream stopReason; apply fallback logic if not provided
|
||||
if stopReason == "" {
|
||||
stopReason = "end_turn"
|
||||
if len(toolUses) > 0 {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
|
||||
}
|
||||
|
||||
// Log warning if response was truncated due to max_tokens
|
||||
if stopReason == "max_tokens" {
|
||||
log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "msg_" + uuid.New().String()[:24],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": contentBlocks,
|
||||
"stop_reason": stopReason,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": usageInfo.InputTokens,
|
||||
"output_tokens": usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return result
|
||||
}
|
||||
|
||||
// ExtractThinkingFromContent parses content to extract thinking blocks and text.
|
||||
// Returns a list of content blocks in the order they appear in the content.
|
||||
// Handles interleaved thinking and text blocks correctly.
|
||||
func ExtractThinkingFromContent(content string) []map[string]interface{} {
|
||||
var blocks []map[string]interface{}
|
||||
|
||||
if content == "" {
|
||||
return blocks
|
||||
}
|
||||
|
||||
// Check if content contains thinking tags at all
|
||||
if !strings.Contains(content, thinkingStartTag) {
|
||||
// No thinking tags, return as plain text
|
||||
return []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content))
|
||||
|
||||
remaining := content
|
||||
|
||||
for len(remaining) > 0 {
|
||||
// Look for <thinking> tag
|
||||
startIdx := strings.Index(remaining, thinkingStartTag)
|
||||
|
||||
if startIdx == -1 {
|
||||
// No more thinking tags, add remaining as text
|
||||
if strings.TrimSpace(remaining) != "" {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": remaining,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Add text before thinking tag (if any meaningful content)
|
||||
if startIdx > 0 {
|
||||
textBefore := remaining[:startIdx]
|
||||
if strings.TrimSpace(textBefore) != "" {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": textBefore,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Move past the opening tag
|
||||
remaining = remaining[startIdx+len(thinkingStartTag):]
|
||||
|
||||
// Find closing tag
|
||||
endIdx := strings.Index(remaining, thinkingEndTag)
|
||||
|
||||
if endIdx == -1 {
|
||||
// No closing tag found, treat rest as thinking content (incomplete response)
|
||||
if strings.TrimSpace(remaining) != "" {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": remaining,
|
||||
})
|
||||
log.Warnf("kiro: extractThinkingFromContent - missing closing </thinking> tag")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Extract thinking content between tags
|
||||
thinkContent := remaining[:endIdx]
|
||||
if strings.TrimSpace(thinkContent) != "" {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkContent,
|
||||
})
|
||||
log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent))
|
||||
}
|
||||
|
||||
// Move past the closing tag
|
||||
remaining = remaining[endIdx+len(thinkingEndTag):]
|
||||
}
|
||||
|
||||
// If no blocks were created (all whitespace), return empty text block
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
})
|
||||
}
|
||||
|
||||
return blocks
|
||||
}
|
||||
176
internal/translator/kiro/claude/kiro_claude_stream.go
Normal file
176
internal/translator/kiro/claude/kiro_claude_stream.go
Normal file
@@ -0,0 +1,176 @@
|
||||
// Package claude provides streaming SSE event building for Claude format.
|
||||
// This package handles the construction of Claude-compatible Server-Sent Events (SSE)
|
||||
// for streaming responses from Kiro API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
// BuildClaudeMessageStartEvent creates the message_start SSE event
|
||||
func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": "msg_" + uuid.New().String()[:24],
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []interface{}{},
|
||||
"model": model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: message_start\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeContentBlockStartEvent returns a "content_block_start" SSE event
// for the block at index.
//
// blockType "tool_use" produces a tool_use block with toolUseID, toolName and
// an empty input object; "thinking" produces an empty thinking block; any
// other value produces an empty text block. toolUseID and toolName are only
// used for tool_use blocks.
func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte {
	// Text is the fallback for every block type not listed below.
	block := map[string]interface{}{
		"type": "text",
		"text": "",
	}
	switch blockType {
	case "tool_use":
		block = map[string]interface{}{
			"type":  "tool_use",
			"id":    toolUseID,
			"name":  toolName,
			"input": map[string]interface{}{},
		}
	case "thinking":
		block = map[string]interface{}{
			"type":     "thinking",
			"thinking": "",
		}
	}

	payload, _ := json.Marshal(map[string]interface{}{
		"type":          "content_block_start",
		"index":         index,
		"content_block": block,
	})
	return append([]byte("event: content_block_start\ndata: "), payload...)
}
|
||||
|
||||
// BuildClaudeStreamEvent returns a "content_block_delta" SSE event carrying
// contentDelta as a text_delta for the block at index.
func BuildClaudeStreamEvent(contentDelta string, index int) []byte {
	delta := map[string]interface{}{
		"type": "text_delta",
		"text": contentDelta,
	}
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "content_block_delta",
		"index": index,
		"delta": delta,
	})
	return append([]byte("event: content_block_delta\ndata: "), payload...)
}
|
||||
|
||||
// BuildClaudeInputJsonDeltaEvent returns a "content_block_delta" SSE event
// carrying partialJSON as an input_json_delta for the tool_use block at
// index. The fragment is embedded as a JSON string, not parsed.
func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte {
	delta := map[string]interface{}{
		"type":         "input_json_delta",
		"partial_json": partialJSON,
	}
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "content_block_delta",
		"index": index,
		"delta": delta,
	})
	return append([]byte("event: content_block_delta\ndata: "), payload...)
}
|
||||
|
||||
// BuildClaudeContentBlockStopEvent returns the "content_block_stop" SSE event
// that closes the content block at index.
func BuildClaudeContentBlockStopEvent(index int) []byte {
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "content_block_stop",
		"index": index,
	})
	return append([]byte("event: content_block_stop\ndata: "), payload...)
}
|
||||
|
||||
// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage
|
||||
func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte {
|
||||
deltaEvent := map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": usageInfo.InputTokens,
|
||||
"output_tokens": usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
deltaResult, _ := json.Marshal(deltaEvent)
|
||||
return []byte("event: message_delta\ndata: " + string(deltaResult))
|
||||
}
|
||||
|
||||
// BuildClaudeMessageStopOnlyEvent returns the bare "message_stop" SSE event,
// without any accompanying message_delta.
func BuildClaudeMessageStopOnlyEvent() []byte {
	payload, _ := json.Marshal(map[string]interface{}{"type": "message_stop"})
	return append([]byte("event: message_stop\ndata: "), payload...)
}
|
||||
|
||||
// BuildClaudePingEventWithUsage returns a "ping" SSE event that carries a
// usage object: the given input and output token counts, their sum as
// total_tokens, and "estimated": true. It is used for real-time usage
// estimation during streaming.
func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
	estimate := map[string]interface{}{
		"input_tokens":  inputTokens,
		"output_tokens": outputTokens,
		"total_tokens":  inputTokens + outputTokens,
		"estimated":     true,
	}
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "ping",
		"usage": estimate,
	})
	return append([]byte("event: ping\ndata: "), payload...)
}
|
||||
|
||||
// BuildClaudeThinkingDeltaEvent returns a "content_block_delta" SSE event
// carrying thinkingDelta as a thinking_delta for the block at index. It is
// used when streaming thinking content that arrived wrapped in <thinking>
// tags.
func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
	delta := map[string]interface{}{
		"type":     "thinking_delta",
		"thinking": thinkingDelta,
	}
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "content_block_delta",
		"index": index,
		"delta": delta,
	})
	return append([]byte("event: content_block_delta\ndata: "), payload...)
}
|
||||
|
||||
// PendingTagSuffix reports how many trailing bytes of buffer form a proper
// prefix of tag, i.e. the start of a tag that may be completed by the next
// chunk. The longest such match is returned, or 0 when there is none or
// either argument is empty. A buffer that ends with the complete tag is not
// counted as pending.
// Based on the amq2api implementation for handling cross-chunk tag boundaries.
func PendingTagSuffix(buffer, tag string) int {
	if len(buffer) == 0 || len(tag) == 0 {
		return 0
	}

	// Only proper prefixes are of interest, so cap at len(tag)-1 bytes.
	longest := len(tag) - 1
	if len(buffer) < longest {
		longest = len(buffer)
	}

	for n := longest; n >= 1; n-- {
		if buffer[len(buffer)-n:] == tag[:n] {
			return n
		}
	}
	return 0
}
|
||||
522
internal/translator/kiro/claude/kiro_claude_tools.go
Normal file
522
internal/translator/kiro/claude/kiro_claude_tools.go
Normal file
@@ -0,0 +1,522 @@
|
||||
// Package claude provides tool calling support for Kiro to Claude translation.
|
||||
// This package handles parsing embedded tool calls, JSON repair, and deduplication.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ToolUseState tracks the state of an in-progress tool use during streaming.
// InputBuffer is a strings.Builder, which must not be copied after first
// use, so a value that has been written to should be shared by pointer.
|
||||
type ToolUseState struct {
|
||||
// ToolUseID is presumably the ID of the tool_use block being assembled — TODO confirm against the stream parser.
ToolUseID string
|
||||
// Name is presumably the name of the tool being called — TODO confirm.
Name string
|
||||
// InputBuffer presumably accumulates the streamed input fragments — TODO confirm.
InputBuffer strings.Builder
|
||||
// IsComplete presumably marks that the tool use has been fully received — TODO confirm.
IsComplete bool
|
||||
}
|
||||
|
||||
// Regular expressions compiled once at package initialisation so they are
// not recompiled on each use.

// embeddedToolCallPattern matches the head of an embedded tool call of the
// form "[Called tool_name with args: " and captures the tool name. The JSON
// arguments that follow are not part of the match.
var embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`)

// trailingCommaPattern matches a comma (plus any whitespace) directly before
// a closing brace or bracket and captures that closing character.
var trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`)
|
||||
|
||||
// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text.
|
||||
// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent.
|
||||
// Returns the cleaned text (with tool calls removed) and extracted tool uses.
|
||||
func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) {
|
||||
if !strings.Contains(text, "[Called") {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
var toolUses []KiroToolUse
|
||||
cleanText := text
|
||||
|
||||
// Find all [Called markers
|
||||
matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1)
|
||||
if len(matches) == 0 {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// Process matches in reverse order to maintain correct indices
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
matchStart := matches[i][0]
|
||||
toolNameStart := matches[i][2]
|
||||
toolNameEnd := matches[i][3]
|
||||
|
||||
if toolNameStart < 0 || toolNameEnd < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := text[toolNameStart:toolNameEnd]
|
||||
|
||||
// Find the JSON object start (after "with args:")
|
||||
jsonStart := matches[i][1]
|
||||
if jsonStart >= len(text) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip whitespace to find the opening brace
|
||||
for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') {
|
||||
jsonStart++
|
||||
}
|
||||
|
||||
if jsonStart >= len(text) || text[jsonStart] != '{' {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find matching closing bracket
|
||||
jsonEnd := findMatchingBracket(text, jsonStart)
|
||||
if jsonEnd < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract JSON and find the closing bracket of [Called ...]
|
||||
jsonStr := text[jsonStart : jsonEnd+1]
|
||||
|
||||
// Find the closing ] after the JSON
|
||||
closingBracket := jsonEnd + 1
|
||||
for closingBracket < len(text) && text[closingBracket] != ']' {
|
||||
closingBracket++
|
||||
}
|
||||
if closingBracket >= len(text) {
|
||||
continue
|
||||
}
|
||||
|
||||
// End index of the full tool call (closing ']' inclusive)
|
||||
matchEnd := closingBracket + 1
|
||||
|
||||
// Repair and parse JSON
|
||||
repairedJSON := RepairJSON(jsonStr)
|
||||
var inputMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil {
|
||||
log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generate unique tool ID
|
||||
toolUseID := "toolu_" + uuid.New().String()[:12]
|
||||
|
||||
// Check for duplicates using name+input as key
|
||||
dedupeKey := toolName + ":" + repairedJSON
|
||||
if processedIDs != nil {
|
||||
if processedIDs[dedupeKey] {
|
||||
log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName)
|
||||
// Still remove from text even if duplicate
|
||||
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
|
||||
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
|
||||
}
|
||||
continue
|
||||
}
|
||||
processedIDs[dedupeKey] = true
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
|
||||
log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID)
|
||||
|
||||
// Remove from clean text (index-based removal to avoid deleting the wrong occurrence)
|
||||
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
|
||||
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
|
||||
}
|
||||
}
|
||||
|
||||
return cleanText, toolUses
|
||||
}
|
||||
|
||||
// findMatchingBracket returns the index of the closing brace/bracket that
// matches the opening '{' or '[' at startPos, or -1 when startPos is out of
// range, does not point at an opening brace/bracket, or no match exists.
//
// Only the bracket kind found at startPos is counted. Characters inside
// double-quoted strings are ignored, and a backslash inside a string escapes
// the following character (so \" does not end the string).
func findMatchingBracket(text string, startPos int) int {
	if startPos < 0 || startPos >= len(text) {
		return -1
	}

	var opener, closer byte
	switch text[startPos] {
	case '{':
		opener, closer = '{', '}'
	case '[':
		opener, closer = '[', ']'
	default:
		return -1
	}

	depth := 1
	inString, escaped := false, false

	for i := startPos + 1; i < len(text); i++ {
		c := text[i]
		switch {
		case escaped:
			// Character following a backslash inside a string: always literal.
			escaped = false
		case inString && c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// Ordinary string content never affects nesting.
		case c == opener:
			depth++
		case c == closer:
			depth--
			if depth == 0 {
				return i
			}
		}
	}

	return -1
}
|
||||
|
||||
// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments.
|
||||
// Conservative repair strategy:
|
||||
// 1. First try to parse JSON directly - if valid, return as-is
|
||||
// 2. Only attempt repair if parsing fails
|
||||
// 3. After repair, validate the result - if still invalid, return original
|
||||
func RepairJSON(jsonString string) string {
|
||||
// Handle empty or invalid input
|
||||
if jsonString == "" {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
str := strings.TrimSpace(jsonString)
|
||||
if str == "" {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
// CONSERVATIVE STRATEGY: First try to parse directly
|
||||
var testParse interface{}
|
||||
if err := json.Unmarshal([]byte(str), &testParse); err == nil {
|
||||
log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged")
|
||||
return str
|
||||
}
|
||||
|
||||
log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair")
|
||||
originalStr := str
|
||||
|
||||
// First, escape unescaped newlines/tabs within JSON string values
|
||||
str = escapeNewlinesInStrings(str)
|
||||
// Remove trailing commas before closing braces/brackets
|
||||
str = trailingCommaPattern.ReplaceAllString(str, "$1")
|
||||
|
||||
// Calculate bracket balance
|
||||
braceCount := 0
|
||||
bracketCount := 0
|
||||
inString := false
|
||||
escape := false
|
||||
lastValidIndex := -1
|
||||
|
||||
for i := 0; i < len(str); i++ {
|
||||
char := str[i]
|
||||
|
||||
if escape {
|
||||
escape = false
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '\\' {
|
||||
escape = true
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
|
||||
switch char {
|
||||
case '{':
|
||||
braceCount++
|
||||
case '}':
|
||||
braceCount--
|
||||
case '[':
|
||||
bracketCount++
|
||||
case ']':
|
||||
bracketCount--
|
||||
}
|
||||
|
||||
if braceCount >= 0 && bracketCount >= 0 {
|
||||
lastValidIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// If brackets are unbalanced, try to repair
|
||||
if braceCount > 0 || bracketCount > 0 {
|
||||
if lastValidIndex > 0 && lastValidIndex < len(str)-1 {
|
||||
truncated := str[:lastValidIndex+1]
|
||||
// Recount brackets after truncation
|
||||
braceCount = 0
|
||||
bracketCount = 0
|
||||
inString = false
|
||||
escape = false
|
||||
for i := 0; i < len(truncated); i++ {
|
||||
char := truncated[i]
|
||||
if escape {
|
||||
escape = false
|
||||
continue
|
||||
}
|
||||
if char == '\\' {
|
||||
escape = true
|
||||
continue
|
||||
}
|
||||
if char == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
switch char {
|
||||
case '{':
|
||||
braceCount++
|
||||
case '}':
|
||||
braceCount--
|
||||
case '[':
|
||||
bracketCount++
|
||||
case ']':
|
||||
bracketCount--
|
||||
}
|
||||
}
|
||||
str = truncated
|
||||
}
|
||||
|
||||
// Add missing closing brackets
|
||||
for braceCount > 0 {
|
||||
str += "}"
|
||||
braceCount--
|
||||
}
|
||||
for bracketCount > 0 {
|
||||
str += "]"
|
||||
bracketCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Validate repaired JSON
|
||||
if err := json.Unmarshal([]byte(str), &testParse); err != nil {
|
||||
log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original")
|
||||
return originalStr
|
||||
}
|
||||
|
||||
log.Debugf("kiro: repairJSON - successfully repaired JSON")
|
||||
return str
|
||||
}
|
||||
|
||||
// escapeNewlinesInStrings escapes literal newlines, tabs, and other control
// characters (any byte below 0x20) that appear inside JSON string values.
// Such bytes are illegal inside a JSON string and make the document
// unparseable. Characters outside of strings, and characters that directly
// follow a backslash inside a string, are copied through unchanged.
func escapeNewlinesInStrings(raw string) string {
	const hexDigits = "0123456789abcdef"

	var result strings.Builder
	result.Grow(len(raw) + 100)

	inString, escaped := false, false

	for i := 0; i < len(raw); i++ {
		c := raw[i]

		switch {
		case escaped:
			// Second half of an existing escape sequence: keep as written.
			result.WriteByte(c)
			escaped = false
		case inString && c == '\\':
			result.WriteByte(c)
			escaped = true
		case c == '"':
			inString = !inString
			result.WriteByte(c)
		case inString && c < 0x20:
			switch c {
			case '\n':
				result.WriteString("\\n")
			case '\r':
				result.WriteString("\\r")
			case '\t':
				result.WriteString("\\t")
			case '\b':
				result.WriteString("\\b")
			case '\f':
				result.WriteString("\\f")
			default:
				// No short form exists; use the generic \u00XX escape.
				result.WriteString("\\u00")
				result.WriteByte(hexDigits[c>>4])
				result.WriteByte(hexDigits[c&0x0f])
			}
		default:
			result.WriteByte(c)
		}
	}

	return result.String()
}
|
||||
|
||||
// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream.
|
||||
// It accumulates input fragments and emits tool_use blocks when complete.
|
||||
// Returns events to emit and updated state.
|
||||
func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) {
|
||||
var toolUses []KiroToolUse
|
||||
|
||||
// Extract from nested toolUseEvent or direct format
|
||||
tu := event
|
||||
if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok {
|
||||
tu = nested
|
||||
}
|
||||
|
||||
toolUseID := kirocommon.GetString(tu, "toolUseId")
|
||||
toolName := kirocommon.GetString(tu, "name")
|
||||
isStop := false
|
||||
if stop, ok := tu["stop"].(bool); ok {
|
||||
isStop = stop
|
||||
}
|
||||
|
||||
// Get input - can be string (fragment) or object (complete)
|
||||
var inputFragment string
|
||||
var inputMap map[string]interface{}
|
||||
|
||||
if inputRaw, ok := tu["input"]; ok {
|
||||
switch v := inputRaw.(type) {
|
||||
case string:
|
||||
inputFragment = v
|
||||
case map[string]interface{}:
|
||||
inputMap = v
|
||||
}
|
||||
}
|
||||
|
||||
// New tool use starting
|
||||
if toolUseID != "" && toolName != "" {
|
||||
if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID {
|
||||
log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous",
|
||||
toolUseID, currentToolUse.ToolUseID)
|
||||
if !processedIDs[currentToolUse.ToolUseID] {
|
||||
incomplete := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: currentToolUse.Name,
|
||||
}
|
||||
if currentToolUse.InputBuffer.Len() > 0 {
|
||||
raw := currentToolUse.InputBuffer.String()
|
||||
repaired := RepairJSON(raw)
|
||||
|
||||
var input map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repaired), &input); err != nil {
|
||||
log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw)
|
||||
input = make(map[string]interface{})
|
||||
}
|
||||
incomplete.Input = input
|
||||
}
|
||||
toolUses = append(toolUses, incomplete)
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
currentToolUse = nil
|
||||
}
|
||||
|
||||
if currentToolUse == nil {
|
||||
if processedIDs != nil && processedIDs[toolUseID] {
|
||||
log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
currentToolUse = &ToolUseState{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
}
|
||||
log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID)
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate input fragments
|
||||
if currentToolUse != nil && inputFragment != "" {
|
||||
currentToolUse.InputBuffer.WriteString(inputFragment)
|
||||
log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len())
|
||||
}
|
||||
|
||||
// If complete input object provided directly
|
||||
if currentToolUse != nil && inputMap != nil {
|
||||
inputBytes, _ := json.Marshal(inputMap)
|
||||
currentToolUse.InputBuffer.Reset()
|
||||
currentToolUse.InputBuffer.Write(inputBytes)
|
||||
}
|
||||
|
||||
// Tool use complete
|
||||
if isStop && currentToolUse != nil {
|
||||
fullInput := currentToolUse.InputBuffer.String()
|
||||
|
||||
// Repair and parse the accumulated JSON
|
||||
repairedJSON := RepairJSON(fullInput)
|
||||
var finalInput map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
|
||||
log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput)
|
||||
finalInput = make(map[string]interface{})
|
||||
}
|
||||
|
||||
toolUse := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: currentToolUse.Name,
|
||||
Input: finalInput,
|
||||
}
|
||||
toolUses = append(toolUses, toolUse)
|
||||
|
||||
if processedIDs != nil {
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
|
||||
log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID)
|
||||
return toolUses, nil
|
||||
}
|
||||
|
||||
return toolUses, currentToolUse
|
||||
}
|
||||
|
||||
// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content.
|
||||
func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
|
||||
seenIDs := make(map[string]bool)
|
||||
seenContent := make(map[string]bool)
|
||||
var unique []KiroToolUse
|
||||
|
||||
for _, tu := range toolUses {
|
||||
if seenIDs[tu.ToolUseID] {
|
||||
log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
inputJSON, _ := json.Marshal(tu.Input)
|
||||
contentKey := tu.Name + ":" + string(inputJSON)
|
||||
|
||||
if seenContent[contentKey] {
|
||||
log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID)
|
||||
continue
|
||||
}
|
||||
|
||||
seenIDs[tu.ToolUseID] = true
|
||||
seenContent[contentKey] = true
|
||||
unique = append(unique, tu)
|
||||
}
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
75
internal/translator/kiro/common/constants.go
Normal file
75
internal/translator/kiro/common/constants.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Package common provides shared constants and utilities for Kiro translator.
|
||||
package common
|
||||
|
||||
const (
	// KiroMaxToolDescLen caps the length of a tool description sent to the
	// Kiro API. The API limit is 10240 bytes; three bytes are held back so a
	// "..." truncation marker still fits.
	KiroMaxToolDescLen = 10237

	// ThinkingStartTag opens a thinking block in a response.
	ThinkingStartTag = "<thinking>"

	// ThinkingEndTag closes a thinking block in a response.
	ThinkingEndTag = "</thinking>"

	// CodeFenceMarker is the backtick form of the markdown code fence.
	CodeFenceMarker = "```"

	// AltCodeFenceMarker is the tilde form of the markdown code fence.
	AltCodeFenceMarker = "~~~"

	// InlineCodeMarker is the single backtick that delimits inline code.
	InlineCodeMarker = "`"

	// KiroAgenticSystemPrompt is the system prompt injected only for -agentic
	// models. It instructs the model to write files in chunks, because the
	// AWS Kiro API times out after 2-3 minutes on large file write operations.
	KiroAgenticSystemPrompt = `
# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY)

You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure.

## ABSOLUTE LIMITS
- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS
- **RECOMMENDED 300 LINES** or less for optimal performance
- **NEVER** write entire files in one operation if >300 lines

## MANDATORY CHUNKED WRITE STRATEGY

### For NEW FILES (>300 lines total):
1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite
2. THEN: Append remaining content in 250-300 line chunks using file append operations
3. REPEAT: Continue appending until complete

### For EDITING EXISTING FILES:
1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed
2. NEVER rewrite entire files - use incremental modifications
3. Split large refactors into multiple small, focused edits

### For LARGE CODE GENERATION:
1. Generate in logical sections (imports, types, functions separately)
2. Write each section as a separate operation
3. Use append operations for subsequent sections

## EXAMPLES OF CORRECT BEHAVIOR

✅ CORRECT: Writing a 600-line file
- Operation 1: Write lines 1-300 (initial file creation)
- Operation 2: Append lines 301-600

✅ CORRECT: Editing multiple functions
- Operation 1: Edit function A
- Operation 2: Edit function B
- Operation 3: Edit function C

❌ WRONG: Writing 500 lines in single operation → TIMEOUT
❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT
❌ WRONG: Generating massive code blocks without chunking → TIMEOUT

## WHY THIS MATTERS
- Server has 2-3 minute timeout for operations
- Large writes exceed timeout and FAIL completely
- Chunked writes are FASTER and more RELIABLE
- Failed writes waste time and require retry

REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
)
|
||||
125
internal/translator/kiro/common/message_merge.go
Normal file
125
internal/translator/kiro/common/message_merge.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Package common provides shared utilities for Kiro translators.
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MergeAdjacentMessages merges adjacent messages with the same role.
|
||||
// This reduces API call complexity and improves compatibility.
|
||||
// Based on AIClient-2-API implementation.
|
||||
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
||||
if len(messages) <= 1 {
|
||||
return messages
|
||||
}
|
||||
|
||||
var merged []gjson.Result
|
||||
for _, msg := range messages {
|
||||
if len(merged) == 0 {
|
||||
merged = append(merged, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
lastMsg := merged[len(merged)-1]
|
||||
currentRole := msg.Get("role").String()
|
||||
lastRole := lastMsg.Get("role").String()
|
||||
|
||||
if currentRole == lastRole {
|
||||
// Merge content from current message into last message
|
||||
mergedContent := mergeMessageContent(lastMsg, msg)
|
||||
// Create a new merged message JSON
|
||||
mergedMsg := createMergedMessage(lastRole, mergedContent)
|
||||
merged[len(merged)-1] = gjson.Parse(mergedMsg)
|
||||
} else {
|
||||
merged = append(merged, msg)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// mergeMessageContent merges the content of two messages with the same role.
|
||||
// Handles both string content and array content (with text, tool_use, tool_result blocks).
|
||||
func mergeMessageContent(msg1, msg2 gjson.Result) string {
|
||||
content1 := msg1.Get("content")
|
||||
content2 := msg2.Get("content")
|
||||
|
||||
// Extract content blocks from both messages
|
||||
var blocks1, blocks2 []map[string]interface{}
|
||||
|
||||
if content1.IsArray() {
|
||||
for _, block := range content1.Array() {
|
||||
blocks1 = append(blocks1, blockToMap(block))
|
||||
}
|
||||
} else if content1.Type == gjson.String {
|
||||
blocks1 = append(blocks1, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content1.String(),
|
||||
})
|
||||
}
|
||||
|
||||
if content2.IsArray() {
|
||||
for _, block := range content2.Array() {
|
||||
blocks2 = append(blocks2, blockToMap(block))
|
||||
}
|
||||
} else if content2.Type == gjson.String {
|
||||
blocks2 = append(blocks2, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content2.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Merge text blocks if both end/start with text
|
||||
if len(blocks1) > 0 && len(blocks2) > 0 {
|
||||
if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" {
|
||||
// Merge the last text block of msg1 with the first text block of msg2
|
||||
text1 := blocks1[len(blocks1)-1]["text"].(string)
|
||||
text2 := blocks2[0]["text"].(string)
|
||||
blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2
|
||||
blocks2 = blocks2[1:] // Remove the merged block from blocks2
|
||||
}
|
||||
}
|
||||
|
||||
// Combine all blocks
|
||||
allBlocks := append(blocks1, blocks2...)
|
||||
|
||||
// Convert to JSON
|
||||
result, _ := json.Marshal(allBlocks)
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// blockToMap converts a gjson.Result block to a map[string]interface{}
|
||||
func blockToMap(block gjson.Result) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
block.ForEach(func(key, value gjson.Result) bool {
|
||||
if value.IsObject() {
|
||||
result[key.String()] = blockToMap(value)
|
||||
} else if value.IsArray() {
|
||||
var arr []interface{}
|
||||
for _, item := range value.Array() {
|
||||
if item.IsObject() {
|
||||
arr = append(arr, blockToMap(item))
|
||||
} else {
|
||||
arr = append(arr, item.Value())
|
||||
}
|
||||
}
|
||||
result[key.String()] = arr
|
||||
} else {
|
||||
result[key.String()] = value.Value()
|
||||
}
|
||||
return true
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// createMergedMessage creates a JSON string for a merged message of the form
// {"content": ..., "role": ...}.
//
// content is expected to be a JSON document (normally an array of content
// blocks) and is embedded verbatim. If it is not valid JSON (including the
// empty string) it is embedded as a plain JSON string instead, so the
// message is never lost to an encoding error.
func createMergedMessage(role string, content string) string {
	var contentValue interface{} = json.RawMessage(content)
	if !json.Valid([]byte(content)) {
		contentValue = content
	}

	result, err := json.Marshal(map[string]interface{}{
		"role":    role,
		"content": contentValue,
	})
	if err != nil {
		// Unreachable for the value types used above; kept for safety.
		return ""
	}
	return string(result)
}
|
||||
16
internal/translator/kiro/common/utils.go
Normal file
16
internal/translator/kiro/common/utils.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package common provides shared constants and utilities for Kiro translator.
|
||||
package common
|
||||
|
||||
// GetString looks up key in m and returns the value when it is a string.
// A missing key, a non-string value, or a nil map all yield "".
func GetString(m map[string]interface{}, key string) string {
	text, isString := m[key].(string)
	if !isString {
		return ""
	}
	return text
}
|
||||
|
||||
// GetStringValue behaves exactly like GetString and is retained for backward
// compatibility: it returns m[key] when that value is a string and ""
// otherwise (missing key, non-string value, or nil map).
func GetStringValue(m map[string]interface{}, key string) string {
	value, _ := m[key].(string)
	return value
}
|
||||
@@ -1,319 +0,0 @@
|
||||
// Package chat_completions provides request translation from OpenAI to Kiro format.
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format.
|
||||
// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format.
|
||||
// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result.
|
||||
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Build Claude-compatible request
|
||||
out := `{"model":"","max_tokens":32000,"messages":[]}`
|
||||
|
||||
// Set model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
|
||||
// Copy max_tokens if present
|
||||
if v := root.Get("max_tokens"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", v.Int())
|
||||
}
|
||||
|
||||
// Copy temperature if present
|
||||
if v := root.Get("temperature"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", v.Float())
|
||||
}
|
||||
|
||||
// Copy top_p if present
|
||||
if v := root.Get("top_p"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", v.Float())
|
||||
}
|
||||
|
||||
// Convert OpenAI tools to Claude tools format
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
claudeTools := make([]interface{}, 0)
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() == "function" {
|
||||
fn := tool.Get("function")
|
||||
claudeTool := map[string]interface{}{
|
||||
"name": fn.Get("name").String(),
|
||||
"description": fn.Get("description").String(),
|
||||
}
|
||||
// Convert parameters to input_schema
|
||||
if params := fn.Get("parameters"); params.Exists() {
|
||||
claudeTool["input_schema"] = params.Value()
|
||||
} else {
|
||||
claudeTool["input_schema"] = map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
claudeTools = append(claudeTools, claudeTool)
|
||||
}
|
||||
}
|
||||
if len(claudeTools) > 0 {
|
||||
out, _ = sjson.Set(out, "tools", claudeTools)
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages
|
||||
messages := root.Get("messages")
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
claudeMessages := make([]interface{}, 0)
|
||||
var systemPrompt string
|
||||
|
||||
// Track pending tool results to merge with next user message
|
||||
var pendingToolResults []map[string]interface{}
|
||||
|
||||
for _, msg := range messages.Array() {
|
||||
role := msg.Get("role").String()
|
||||
content := msg.Get("content")
|
||||
|
||||
if role == "system" {
|
||||
// Extract system message
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
systemPrompt += part.Get("text").String() + "\n"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
systemPrompt = content.String()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if role == "tool" {
|
||||
// OpenAI tool message -> Claude tool_result content block
|
||||
toolCallID := msg.Get("tool_call_id").String()
|
||||
toolContent := content.String()
|
||||
|
||||
toolResult := map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolCallID,
|
||||
}
|
||||
|
||||
// Handle content - can be string or structured
|
||||
if content.IsArray() {
|
||||
contentParts := make([]interface{}, 0)
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
toolResult["content"] = contentParts
|
||||
} else {
|
||||
toolResult["content"] = toolContent
|
||||
}
|
||||
|
||||
pendingToolResults = append(pendingToolResults, toolResult)
|
||||
continue
|
||||
}
|
||||
|
||||
claudeMsg := map[string]interface{}{
|
||||
"role": role,
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool_calls
|
||||
if role == "assistant" && msg.Get("tool_calls").Exists() {
|
||||
contentParts := make([]interface{}, 0)
|
||||
|
||||
// Add text content if present
|
||||
if content.Exists() && content.String() != "" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Convert tool_calls to tool_use blocks
|
||||
for _, toolCall := range msg.Get("tool_calls").Array() {
|
||||
toolUseID := toolCall.Get("id").String()
|
||||
fnName := toolCall.Get("function.name").String()
|
||||
fnArgs := toolCall.Get("function.arguments").String()
|
||||
|
||||
// Parse arguments JSON
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil {
|
||||
argsMap = map[string]interface{}{"raw": fnArgs}
|
||||
}
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseID,
|
||||
"name": fnName,
|
||||
"input": argsMap,
|
||||
})
|
||||
}
|
||||
|
||||
claudeMsg["content"] = contentParts
|
||||
claudeMessages = append(claudeMessages, claudeMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle user messages - may need to include pending tool results
|
||||
if role == "user" && len(pendingToolResults) > 0 {
|
||||
contentParts := make([]interface{}, 0)
|
||||
|
||||
// Add pending tool results first
|
||||
for _, tr := range pendingToolResults {
|
||||
contentParts = append(contentParts, tr)
|
||||
}
|
||||
pendingToolResults = nil
|
||||
|
||||
// Add user content
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "text" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
} else if partType == "image_url" {
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
|
||||
// Check if it's base64 format (data:image/png;base64,xxxxx)
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Parse data URL format
|
||||
// Format: data:image/png;base64,xxxxx
|
||||
commaIdx := strings.Index(imageURL, ",")
|
||||
if commaIdx != -1 {
|
||||
// Extract media_type (e.g., "image/png")
|
||||
header := imageURL[5:commaIdx] // Remove "data:" prefix
|
||||
mediaType := header
|
||||
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
|
||||
mediaType = header[:semiIdx]
|
||||
}
|
||||
|
||||
// Extract base64 data
|
||||
base64Data := imageURL[commaIdx+1:]
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": mediaType,
|
||||
"data": base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Regular URL format - keep original logic
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "url",
|
||||
"url": imageURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if content.String() != "" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
claudeMsg["content"] = contentParts
|
||||
claudeMessages = append(claudeMessages, claudeMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular content
|
||||
if content.IsArray() {
|
||||
contentParts := make([]interface{}, 0)
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "text" {
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
} else if partType == "image_url" {
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
|
||||
// Check if it's base64 format (data:image/png;base64,xxxxx)
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Parse data URL format
|
||||
// Format: data:image/png;base64,xxxxx
|
||||
commaIdx := strings.Index(imageURL, ",")
|
||||
if commaIdx != -1 {
|
||||
// Extract media_type (e.g., "image/png")
|
||||
header := imageURL[5:commaIdx] // Remove "data:" prefix
|
||||
mediaType := header
|
||||
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
|
||||
mediaType = header[:semiIdx]
|
||||
}
|
||||
|
||||
// Extract base64 data
|
||||
base64Data := imageURL[commaIdx+1:]
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": mediaType,
|
||||
"data": base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Regular URL format - keep original logic
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "url",
|
||||
"url": imageURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
claudeMsg["content"] = contentParts
|
||||
} else {
|
||||
claudeMsg["content"] = content.String()
|
||||
}
|
||||
|
||||
claudeMessages = append(claudeMessages, claudeMsg)
|
||||
}
|
||||
|
||||
// If there are pending tool results without a following user message,
|
||||
// create a user message with just the tool results
|
||||
if len(pendingToolResults) > 0 {
|
||||
contentParts := make([]interface{}, 0)
|
||||
for _, tr := range pendingToolResults {
|
||||
contentParts = append(contentParts, tr)
|
||||
}
|
||||
claudeMessages = append(claudeMessages, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": contentParts,
|
||||
})
|
||||
}
|
||||
|
||||
out, _ = sjson.Set(out, "messages", claudeMessages)
|
||||
|
||||
if systemPrompt != "" {
|
||||
out, _ = sjson.Set(out, "system", systemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Set stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
@@ -1,373 +0,0 @@
|
||||
// Package chat_completions provides response translation from Kiro to OpenAI format.
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format.
|
||||
// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta,
|
||||
// content_block_stop, message_delta, and message_stop.
|
||||
// Input may be in SSE format: "event: xxx\ndata: {...}" or raw JSON.
|
||||
func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
raw := string(rawResponse)
|
||||
var results []string
|
||||
|
||||
// Handle SSE format: extract JSON from "data: " lines
|
||||
// Input format: "event: message_start\ndata: {...}"
|
||||
lines := strings.Split(raw, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
jsonPart := strings.TrimPrefix(line, "data: ")
|
||||
chunks := convertClaudeEventToOpenAI(jsonPart, model)
|
||||
results = append(results, chunks...)
|
||||
} else if strings.HasPrefix(line, "{") {
|
||||
// Raw JSON (backward compatibility)
|
||||
chunks := convertClaudeEventToOpenAI(line, model)
|
||||
results = append(results, chunks...)
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// convertClaudeEventToOpenAI converts a single Claude JSON event to OpenAI format
|
||||
func convertClaudeEventToOpenAI(jsonStr string, model string) []string {
|
||||
root := gjson.Parse(jsonStr)
|
||||
var results []string
|
||||
|
||||
eventType := root.Get("type").String()
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
// Initial message event - emit initial chunk with role
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
return results
|
||||
|
||||
case "content_block_start":
|
||||
// Start of a content block (text or tool_use)
|
||||
blockType := root.Get("content_block.type").String()
|
||||
index := int(root.Get("index").Int())
|
||||
|
||||
if blockType == "tool_use" {
|
||||
// Start of tool_use block
|
||||
toolUseID := root.Get("content_block.id").String()
|
||||
toolName := root.Get("content_block.name").String()
|
||||
|
||||
toolCall := map[string]interface{}{
|
||||
"index": index,
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
return results
|
||||
|
||||
case "content_block_delta":
|
||||
index := int(root.Get("index").Int())
|
||||
deltaType := root.Get("delta.type").String()
|
||||
|
||||
if deltaType == "text_delta" {
|
||||
// Text content delta
|
||||
contentDelta := root.Get("delta.text").String()
|
||||
if contentDelta != "" {
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"content": contentDelta,
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
} else if deltaType == "input_json_delta" {
|
||||
// Tool input delta (streaming arguments)
|
||||
partialJSON := root.Get("delta.partial_json").String()
|
||||
if partialJSON != "" {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": index,
|
||||
"function": map[string]interface{}{
|
||||
"arguments": partialJSON,
|
||||
},
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
}
|
||||
return results
|
||||
|
||||
case "content_block_stop":
|
||||
// End of content block - no output needed for OpenAI format
|
||||
return results
|
||||
|
||||
case "message_delta":
|
||||
// Final message delta with stop_reason and usage
|
||||
stopReason := root.Get("delta.stop_reason").String()
|
||||
if stopReason != "" {
|
||||
finishReason := "stop"
|
||||
if stopReason == "tool_use" {
|
||||
finishReason = "tool_calls"
|
||||
} else if stopReason == "end_turn" {
|
||||
finishReason = "stop"
|
||||
} else if stopReason == "max_tokens" {
|
||||
finishReason = "length"
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract and include usage information from message_delta event
|
||||
usage := root.Get("usage")
|
||||
if usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
response["usage"] = map[string]interface{}{
|
||||
"prompt_tokens": inputTokens,
|
||||
"completion_tokens": outputTokens,
|
||||
"total_tokens": inputTokens + outputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
return results
|
||||
|
||||
case "message_stop":
|
||||
// End of message - could emit [DONE] marker
|
||||
return results
|
||||
}
|
||||
|
||||
// Fallback: handle raw content for backward compatibility
|
||||
var contentDelta string
|
||||
if delta := root.Get("delta.text"); delta.Exists() {
|
||||
contentDelta = delta.String()
|
||||
} else if content := root.Get("content"); content.Exists() && root.Get("type").String() == "" {
|
||||
contentDelta = content.String()
|
||||
}
|
||||
|
||||
if contentDelta != "" {
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"content": contentDelta,
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
|
||||
// Handle tool_use content blocks (Claude format) - fallback
|
||||
toolUses := root.Get("delta.tool_use")
|
||||
if !toolUses.Exists() {
|
||||
toolUses = root.Get("tool_use")
|
||||
}
|
||||
if toolUses.Exists() && toolUses.IsObject() {
|
||||
inputJSON := toolUses.Get("input").String()
|
||||
if inputJSON == "" {
|
||||
if inputObj := toolUses.Get("input"); inputObj.Exists() {
|
||||
inputBytes, _ := json.Marshal(inputObj.Value())
|
||||
inputJSON = string(inputBytes)
|
||||
}
|
||||
}
|
||||
|
||||
toolCall := map[string]interface{}{
|
||||
"index": 0,
|
||||
"id": toolUses.Get("id").String(),
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolUses.Get("name").String(),
|
||||
"arguments": inputJSON,
|
||||
},
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
results = append(results, string(result))
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format.
|
||||
func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
root := gjson.ParseBytes(rawResponse)
|
||||
|
||||
var content string
|
||||
var toolCalls []map[string]interface{}
|
||||
|
||||
contentArray := root.Get("content")
|
||||
if contentArray.IsArray() {
|
||||
for _, item := range contentArray.Array() {
|
||||
itemType := item.Get("type").String()
|
||||
if itemType == "text" {
|
||||
content += item.Get("text").String()
|
||||
} else if itemType == "tool_use" {
|
||||
// Convert Claude tool_use to OpenAI tool_calls format
|
||||
inputJSON := item.Get("input").String()
|
||||
if inputJSON == "" {
|
||||
// If input is an object, marshal it
|
||||
if inputObj := item.Get("input"); inputObj.Exists() {
|
||||
inputBytes, _ := json.Marshal(inputObj.Value())
|
||||
inputJSON = string(inputBytes)
|
||||
}
|
||||
}
|
||||
toolCall := map[string]interface{}{
|
||||
"id": item.Get("id").String(),
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": item.Get("name").String(),
|
||||
"arguments": inputJSON,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
content = root.Get("content").String()
|
||||
}
|
||||
|
||||
inputTokens := root.Get("usage.input_tokens").Int()
|
||||
outputTokens := root.Get("usage.output_tokens").Int()
|
||||
|
||||
message := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
// Add tool_calls if present
|
||||
if len(toolCalls) > 0 {
|
||||
message["tool_calls"] = toolCalls
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": inputTokens,
|
||||
"completion_tokens": outputTokens,
|
||||
"total_tokens": inputTokens + outputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return string(result)
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
package chat_completions
|
||||
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
|
||||
package openai
|
||||
|
||||
import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -8,12 +9,12 @@ import (
|
||||
|
||||
func init() {
|
||||
translator.Register(
|
||||
OpenAI,
|
||||
Kiro,
|
||||
OpenAI, // source format
|
||||
Kiro, // target format
|
||||
ConvertOpenAIRequestToKiro,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertKiroResponseToOpenAI,
|
||||
NonStream: ConvertKiroResponseToOpenAINonStream,
|
||||
Stream: ConvertKiroStreamToOpenAI,
|
||||
NonStream: ConvertKiroNonStreamToOpenAI,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
369
internal/translator/kiro/openai/kiro_openai.go
Normal file
369
internal/translator/kiro/openai/kiro_openai.go
Normal file
@@ -0,0 +1,369 @@
|
||||
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
|
||||
// This package enables direct OpenAI → Kiro translation, bypassing the Claude intermediate layer.
|
||||
//
|
||||
// The Kiro executor generates Claude-compatible SSE format internally, so the streaming response
|
||||
// translation converts from Claude SSE format to OpenAI SSE format.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ConvertKiroStreamToOpenAI converts Kiro streaming response to OpenAI format.
|
||||
// The Kiro executor emits Claude-compatible SSE events, so this function translates
|
||||
// from Claude SSE format to OpenAI SSE format.
|
||||
//
|
||||
// Claude SSE format:
|
||||
// - event: message_start\ndata: {...}
|
||||
// - event: content_block_start\ndata: {...}
|
||||
// - event: content_block_delta\ndata: {...}
|
||||
// - event: content_block_stop\ndata: {...}
|
||||
// - event: message_delta\ndata: {...}
|
||||
// - event: message_stop\ndata: {...}
|
||||
//
|
||||
// OpenAI SSE format:
|
||||
// - data: {"id":"...","object":"chat.completion.chunk",...}
|
||||
// - data: [DONE]
|
||||
func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||
// Initialize state if needed
|
||||
if *param == nil {
|
||||
*param = NewOpenAIStreamState(model)
|
||||
}
|
||||
state := (*param).(*OpenAIStreamState)
|
||||
|
||||
// Parse the Claude SSE event
|
||||
responseStr := string(rawResponse)
|
||||
|
||||
// Handle raw event format (event: xxx\ndata: {...})
|
||||
var eventType string
|
||||
var eventData string
|
||||
|
||||
if strings.HasPrefix(responseStr, "event:") {
|
||||
// Parse event type and data
|
||||
lines := strings.SplitN(responseStr, "\n", 2)
|
||||
if len(lines) >= 1 {
|
||||
eventType = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||
}
|
||||
if len(lines) >= 2 && strings.HasPrefix(lines[1], "data:") {
|
||||
eventData = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||
}
|
||||
} else if strings.HasPrefix(responseStr, "data:") {
|
||||
// Just data line
|
||||
eventData = strings.TrimSpace(strings.TrimPrefix(responseStr, "data:"))
|
||||
} else {
|
||||
// Try to parse as raw JSON
|
||||
eventData = strings.TrimSpace(responseStr)
|
||||
}
|
||||
|
||||
if eventData == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Parse the event data as JSON
|
||||
eventJSON := gjson.Parse(eventData)
|
||||
if !eventJSON.Exists() {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Determine event type from JSON if not already set
|
||||
if eventType == "" {
|
||||
eventType = eventJSON.Get("type").String()
|
||||
}
|
||||
|
||||
var results []string
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
// Send first chunk with role
|
||||
firstChunk := BuildOpenAISSEFirstChunk(state)
|
||||
results = append(results, firstChunk)
|
||||
|
||||
case "content_block_start":
|
||||
// Check block type
|
||||
blockType := eventJSON.Get("content_block.type").String()
|
||||
switch blockType {
|
||||
case "text":
|
||||
// Text block starting - nothing to emit yet
|
||||
case "thinking":
|
||||
// Thinking block starting - nothing to emit yet for OpenAI
|
||||
case "tool_use":
|
||||
// Tool use block starting
|
||||
toolUseID := eventJSON.Get("content_block.id").String()
|
||||
toolName := eventJSON.Get("content_block.name").String()
|
||||
chunk := BuildOpenAISSEToolCallStart(state, toolUseID, toolName)
|
||||
results = append(results, chunk)
|
||||
state.ToolCallIndex++
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
deltaType := eventJSON.Get("delta.type").String()
|
||||
switch deltaType {
|
||||
case "text_delta":
|
||||
textDelta := eventJSON.Get("delta.text").String()
|
||||
if textDelta != "" {
|
||||
chunk := BuildOpenAISSETextDelta(state, textDelta)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Convert thinking to reasoning_content for o1-style compatibility
|
||||
thinkingDelta := eventJSON.Get("delta.thinking").String()
|
||||
if thinkingDelta != "" {
|
||||
chunk := BuildOpenAISSEReasoningDelta(state, thinkingDelta)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
case "input_json_delta":
|
||||
// Tool call arguments delta
|
||||
partialJSON := eventJSON.Get("delta.partial_json").String()
|
||||
if partialJSON != "" {
|
||||
// Get the tool index from content block index
|
||||
blockIndex := int(eventJSON.Get("index").Int())
|
||||
chunk := BuildOpenAISSEToolCallArgumentsDelta(state, partialJSON, blockIndex-1) // Adjust for 0-based tool index
|
||||
results = append(results, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Content block ended - nothing to emit for OpenAI
|
||||
|
||||
case "message_delta":
|
||||
// Message delta with stop_reason
|
||||
stopReason := eventJSON.Get("delta.stop_reason").String()
|
||||
finishReason := mapKiroStopReasonToOpenAI(stopReason)
|
||||
if finishReason != "" {
|
||||
chunk := BuildOpenAISSEFinish(state, finishReason)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
|
||||
// Extract usage if present
|
||||
if eventJSON.Get("usage").Exists() {
|
||||
inputTokens := eventJSON.Get("usage.input_tokens").Int()
|
||||
outputTokens := eventJSON.Get("usage.output_tokens").Int()
|
||||
usageInfo := usage.Detail{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: inputTokens + outputTokens,
|
||||
}
|
||||
chunk := BuildOpenAISSEUsage(state, usageInfo)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
// Final event - do NOT emit [DONE] here
|
||||
// The handler layer (openai_handlers.go) will send [DONE] when the stream closes
|
||||
// Emitting [DONE] here would cause duplicate [DONE] markers
|
||||
|
||||
case "ping":
|
||||
// Ping event with usage - optionally emit usage chunk
|
||||
if eventJSON.Get("usage").Exists() {
|
||||
inputTokens := eventJSON.Get("usage.input_tokens").Int()
|
||||
outputTokens := eventJSON.Get("usage.output_tokens").Int()
|
||||
usageInfo := usage.Detail{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: inputTokens + outputTokens,
|
||||
}
|
||||
chunk := BuildOpenAISSEUsage(state, usageInfo)
|
||||
results = append(results, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// ConvertKiroNonStreamToOpenAI converts Kiro non-streaming response to OpenAI format.
|
||||
// The Kiro executor returns Claude-compatible JSON responses, so this function translates
|
||||
// from Claude format to OpenAI format.
|
||||
func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||
// Parse the Claude-format response
|
||||
response := gjson.ParseBytes(rawResponse)
|
||||
|
||||
// Extract content
|
||||
var content string
|
||||
var toolUses []KiroToolUse
|
||||
var stopReason string
|
||||
|
||||
// Get stop_reason
|
||||
stopReason = response.Get("stop_reason").String()
|
||||
|
||||
// Process content blocks
|
||||
contentBlocks := response.Get("content")
|
||||
if contentBlocks.IsArray() {
|
||||
for _, block := range contentBlocks.Array() {
|
||||
blockType := block.Get("type").String()
|
||||
switch blockType {
|
||||
case "text":
|
||||
content += block.Get("text").String()
|
||||
case "thinking":
|
||||
// Skip thinking blocks for OpenAI format (or convert to reasoning_content if needed)
|
||||
case "tool_use":
|
||||
toolUseID := block.Get("id").String()
|
||||
toolName := block.Get("name").String()
|
||||
toolInput := block.Get("input")
|
||||
|
||||
var inputMap map[string]interface{}
|
||||
if toolInput.IsObject() {
|
||||
inputMap = make(map[string]interface{})
|
||||
toolInput.ForEach(func(key, value gjson.Result) bool {
|
||||
inputMap[key.String()] = value.Value()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage
|
||||
usageInfo := usage.Detail{
|
||||
InputTokens: response.Get("usage.input_tokens").Int(),
|
||||
OutputTokens: response.Get("usage.output_tokens").Int(),
|
||||
}
|
||||
usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
|
||||
|
||||
// Build OpenAI response
|
||||
openaiResponse := BuildOpenAIResponse(content, toolUses, model, usageInfo, stopReason)
|
||||
return string(openaiResponse)
|
||||
}
|
||||
|
||||
// ParseClaudeEvent splits a raw Claude SSE event into its event name and data.
//
// Each line is trimmed of surrounding whitespace. The text after an "event:"
// prefix becomes eventType and the text after a "data:" prefix becomes eventData,
// both trimmed again. When a prefix occurs on several lines the last one wins.
// Lines with neither prefix are ignored, so raw JSON input yields empty results.
func ParseClaudeEvent(rawEvent []byte) (eventType string, eventData []byte) {
	var (
		eventPrefix = []byte("event:")
		dataPrefix  = []byte("data:")
	)

	for _, rawLine := range bytes.Split(rawEvent, []byte("\n")) {
		trimmed := bytes.TrimSpace(rawLine)
		switch {
		case bytes.HasPrefix(trimmed, eventPrefix):
			eventType = string(bytes.TrimSpace(trimmed[len(eventPrefix):]))
		case bytes.HasPrefix(trimmed, dataPrefix):
			eventData = bytes.TrimSpace(trimmed[len(dataPrefix):])
		}
	}
	return eventType, eventData
}
|
||||
|
||||
// ExtractThinkingFromContent parses content to extract thinking blocks.
|
||||
// Returns cleaned content (without thinking tags) and whether thinking was found.
|
||||
func ExtractThinkingFromContent(content string) (string, string, bool) {
|
||||
if !strings.Contains(content, kirocommon.ThinkingStartTag) {
|
||||
return content, "", false
|
||||
}
|
||||
|
||||
var cleanedContent strings.Builder
|
||||
var thinkingContent strings.Builder
|
||||
hasThinking := false
|
||||
remaining := content
|
||||
|
||||
for len(remaining) > 0 {
|
||||
startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag)
|
||||
if startIdx == -1 {
|
||||
cleanedContent.WriteString(remaining)
|
||||
break
|
||||
}
|
||||
|
||||
// Add content before thinking tag
|
||||
cleanedContent.WriteString(remaining[:startIdx])
|
||||
|
||||
// Move past opening tag
|
||||
remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):]
|
||||
|
||||
// Find closing tag
|
||||
endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag)
|
||||
if endIdx == -1 {
|
||||
// No closing tag - treat rest as thinking
|
||||
thinkingContent.WriteString(remaining)
|
||||
hasThinking = true
|
||||
break
|
||||
}
|
||||
|
||||
// Extract thinking content
|
||||
thinkingContent.WriteString(remaining[:endIdx])
|
||||
hasThinking = true
|
||||
remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):]
|
||||
}
|
||||
|
||||
return strings.TrimSpace(cleanedContent.String()), strings.TrimSpace(thinkingContent.String()), hasThinking
|
||||
}
|
||||
|
||||
// ConvertOpenAIToolsToKiroFormat is a helper that converts OpenAI tools format to Kiro format
|
||||
func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
|
||||
for _, tool := range tools {
|
||||
toolType, _ := tool["type"].(string)
|
||||
if toolType != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
fn, ok := tool["function"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name := kirocommon.GetString(fn, "name")
|
||||
description := kirocommon.GetString(fn, "description")
|
||||
parameters := fn["parameters"]
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if description == "" {
|
||||
description = "Tool: " + name
|
||||
}
|
||||
|
||||
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: KiroInputSchema{JSON: parameters},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// OpenAIStreamParams holds parameters for OpenAI streaming conversion
|
||||
type OpenAIStreamParams struct {
|
||||
State *OpenAIStreamState
|
||||
ThinkingState *ThinkingTagState
|
||||
ToolCallsEmitted map[string]bool
|
||||
}
|
||||
|
||||
// NewOpenAIStreamParams creates new streaming parameters
|
||||
func NewOpenAIStreamParams(model string) *OpenAIStreamParams {
|
||||
return &OpenAIStreamParams{
|
||||
State: NewOpenAIStreamState(model),
|
||||
ThinkingState: NewThinkingTagState(),
|
||||
ToolCallsEmitted: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertClaudeToolUseToOpenAI converts a Claude tool_use block into an OpenAI
// tool_calls entry.
//
// The returned map has "id" (toolUseID), "type" ("function") and "function",
// which holds "name" (toolName) and "arguments": the input encoded as a JSON
// object string. A nil input, or an input that cannot be JSON-encoded, yields
// "{}" so that "arguments" always decodes to a JSON object.
func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]interface{}) map[string]interface{} {
	// Default to an empty object: a nil map would encode as "null" and an
	// encoding failure would otherwise leave the arguments string empty.
	arguments := "{}"
	if input != nil {
		if encoded, err := json.Marshal(input); err == nil {
			arguments = string(encoded)
		}
	}

	return map[string]interface{}{
		"id":   toolUseID,
		"type": "function",
		"function": map[string]interface{}{
			"name":      toolName,
			"arguments": arguments,
		},
	}
}
|
||||
|
||||
// LogStreamEvent logs a streaming event for debugging
|
||||
func LogStreamEvent(eventType, data string) {
|
||||
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
||||
}
|
||||
848
internal/translator/kiro/openai/kiro_openai_request.go
Normal file
848
internal/translator/kiro/openai/kiro_openai_request.go
Normal file
@@ -0,0 +1,848 @@
|
||||
// Package openai provides request translation from OpenAI Chat Completions to Kiro format.
|
||||
// It handles parsing and transforming OpenAI API requests into the Kiro/Amazon Q API format,
|
||||
// extracting model information, system instructions, message contents, and tool declarations.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Kiro API request structs - reuse from kiroclaude package structure
|
||||
|
||||
type (
	// KiroPayload is the root object of a Kiro API request.
	KiroPayload struct {
		ConversationState KiroConversationState `json:"conversationState"`
		ProfileArn        string                `json:"profileArn,omitempty"`
		InferenceConfig   *KiroInferenceConfig  `json:"inferenceConfig,omitempty"`
	}

	// KiroInferenceConfig carries the inference parameters of a Kiro request.
	// Every field is tagged omitempty, so a zero value is left out of the JSON.
	KiroInferenceConfig struct {
		MaxTokens   int     `json:"maxTokens,omitempty"`
		Temperature float64 `json:"temperature,omitempty"`
		TopP        float64 `json:"topP,omitempty"`
	}

	// KiroConversationState is the conversation context: the message being sent
	// plus any earlier history.
	KiroConversationState struct {
		ChatTriggerType string               `json:"chatTriggerType"` // Required: "MANUAL"
		ConversationID  string               `json:"conversationId"`
		CurrentMessage  KiroCurrentMessage   `json:"currentMessage"`
		History         []KiroHistoryMessage `json:"history,omitempty"`
	}

	// KiroCurrentMessage wraps the user message currently being sent.
	KiroCurrentMessage struct {
		UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
	}

	// KiroHistoryMessage is one entry of the conversation history; it holds
	// either a user message or an assistant response.
	KiroHistoryMessage struct {
		UserInputMessage         *KiroUserInputMessage         `json:"userInputMessage,omitempty"`
		AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
	}

	// KiroImage is an image attachment in Kiro API format.
	KiroImage struct {
		Format string          `json:"format"`
		Source KiroImageSource `json:"source"`
	}

	// KiroImageSource holds the image payload.
	KiroImageSource struct {
		Bytes string `json:"bytes"` // base64 encoded image data
	}

	// KiroUserInputMessage is a user-authored message.
	KiroUserInputMessage struct {
		Content                 string                       `json:"content"`
		ModelID                 string                       `json:"modelId"`
		Origin                  string                       `json:"origin"`
		Images                  []KiroImage                  `json:"images,omitempty"`
		UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
	}

	// KiroUserInputMessageContext carries the tool-related context of a user
	// message: results of earlier tool calls and the tools on offer.
	KiroUserInputMessageContext struct {
		ToolResults []KiroToolResult  `json:"toolResults,omitempty"`
		Tools       []KiroToolWrapper `json:"tools,omitempty"`
	}

	// KiroToolResult is the outcome of one tool execution.
	KiroToolResult struct {
		Content   []KiroTextContent `json:"content"`
		Status    string            `json:"status"`
		ToolUseID string            `json:"toolUseId"`
	}

	// KiroTextContent is a piece of plain text content.
	KiroTextContent struct {
		Text string `json:"text"`
	}

	// KiroToolWrapper wraps a single tool specification.
	KiroToolWrapper struct {
		ToolSpecification KiroToolSpecification `json:"toolSpecification"`
	}

	// KiroToolSpecification describes a tool: its name, description and input schema.
	KiroToolSpecification struct {
		Name        string          `json:"name"`
		Description string          `json:"description"`
		InputSchema KiroInputSchema `json:"inputSchema"`
	}

	// KiroInputSchema wraps the JSON schema of a tool's input.
	KiroInputSchema struct {
		JSON interface{} `json:"json"`
	}

	// KiroAssistantResponseMessage is an assistant-authored message, optionally
	// with the tool invocations it made.
	KiroAssistantResponseMessage struct {
		Content  string        `json:"content"`
		ToolUses []KiroToolUse `json:"toolUses,omitempty"`
	}

	// KiroToolUse is one tool invocation made by the assistant.
	KiroToolUse struct {
		ToolUseID string                 `json:"toolUseId"`
		Name      string                 `json:"name"`
		Input     map[string]interface{} `json:"input"`
	}
)
|
||||
|
||||
// ConvertOpenAIRequestToKiro is the request-translation entry point for
// OpenAI Chat Completions → Kiro.
//
// It is a deliberate pass-through: inputRawJSON is returned unchanged (same
// slice, no copy) and modelName and stream are not read. The OpenAI body is
// turned into a Kiro payload later, by BuildKiroPayloadFromOpenAI.
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
	// No translation at this stage; see BuildKiroPayloadFromOpenAI.
	return inputRawJSON
}
|
||||
|
||||
// BuildKiroPayloadFromOpenAI constructs the Kiro API request payload from OpenAI format.
|
||||
// Supports tool calling - tools are passed via userInputMessageContext.
|
||||
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||
func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||
const kiroMaxOutputTokens = 32000
|
||||
var maxTokens int64
|
||||
if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() {
|
||||
maxTokens = mt.Int()
|
||||
if maxTokens == -1 {
|
||||
maxTokens = kiroMaxOutputTokens
|
||||
log.Debugf("kiro-openai: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract temperature if specified
|
||||
var temperature float64
|
||||
var hasTemperature bool
|
||||
if temp := gjson.GetBytes(openaiBody, "temperature"); temp.Exists() {
|
||||
temperature = temp.Float()
|
||||
hasTemperature = true
|
||||
}
|
||||
|
||||
// Extract top_p if specified
|
||||
var topP float64
|
||||
var hasTopP bool
|
||||
if tp := gjson.GetBytes(openaiBody, "top_p"); tp.Exists() {
|
||||
topP = tp.Float()
|
||||
hasTopP = true
|
||||
log.Debugf("kiro-openai: extracted top_p: %.2f", topP)
|
||||
}
|
||||
|
||||
// Normalize origin value for Kiro API compatibility
|
||||
origin = normalizeOrigin(origin)
|
||||
log.Debugf("kiro-openai: normalized origin value: %s", origin)
|
||||
|
||||
messages := gjson.GetBytes(openaiBody, "messages")
|
||||
|
||||
// For chat-only mode, don't include tools
|
||||
var tools gjson.Result
|
||||
if !isChatOnly {
|
||||
tools = gjson.GetBytes(openaiBody, "tools")
|
||||
}
|
||||
|
||||
// Extract system prompt from messages
|
||||
systemPrompt := extractSystemPromptFromOpenAI(messages)
|
||||
|
||||
// Inject timestamp context
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = timestampContext + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = timestampContext
|
||||
}
|
||||
log.Debugf("kiro-openai: injected timestamp context: %s", timestamp)
|
||||
|
||||
// Inject agentic optimization prompt for -agentic model variants
|
||||
if isAgentic {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += kirocommon.KiroAgenticSystemPrompt
|
||||
}
|
||||
|
||||
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||
// OpenAI tool_choice values: "none", "auto", "required", or {"type":"function","function":{"name":"..."}}
|
||||
toolChoiceHint := extractToolChoiceHint(openaiBody)
|
||||
if toolChoiceHint != "" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += toolChoiceHint
|
||||
log.Debugf("kiro-openai: injected tool_choice hint into system prompt")
|
||||
}
|
||||
|
||||
// Handle response_format parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||
// OpenAI response_format: {"type": "json_object"} or {"type": "json_schema", "json_schema": {...}}
|
||||
responseFormatHint := extractResponseFormatHint(openaiBody)
|
||||
if responseFormatHint != "" {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
systemPrompt += responseFormatHint
|
||||
log.Debugf("kiro-openai: injected response_format hint into system prompt")
|
||||
}
|
||||
|
||||
// Check for thinking mode and inject thinking hint
|
||||
// Supports OpenAI reasoning_effort parameter and model name hints
|
||||
thinkingEnabled, budgetTokens := checkThinkingModeFromOpenAI(openaiBody)
|
||||
if thinkingEnabled {
|
||||
// Adjust budgetTokens based on max_tokens if not explicitly set by reasoning_effort
|
||||
// Use 50% of max_tokens for thinking, with min 8000 and max 24000
|
||||
if maxTokens > 0 && budgetTokens == 16000 { // 16000 is the default, meaning not explicitly set
|
||||
calculatedBudget := maxTokens / 2
|
||||
if calculatedBudget < 8000 {
|
||||
calculatedBudget = 8000
|
||||
}
|
||||
if calculatedBudget > 24000 {
|
||||
calculatedBudget = 24000
|
||||
}
|
||||
budgetTokens = calculatedBudget
|
||||
log.Debugf("kiro-openai: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens)
|
||||
}
|
||||
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
dynamicThinkingHint := fmt.Sprintf("<thinking_mode>interleaved</thinking_mode><max_thinking_length>%d</max_thinking_length>", budgetTokens)
|
||||
systemPrompt += dynamicThinkingHint
|
||||
log.Debugf("kiro-openai: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens)
|
||||
}
|
||||
|
||||
// Convert OpenAI tools to Kiro format
|
||||
kiroTools := convertOpenAIToolsToKiro(tools)
|
||||
|
||||
// Process messages and build history
|
||||
history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin)
|
||||
|
||||
// Build content with system prompt
|
||||
if currentUserMsg != nil {
|
||||
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||
|
||||
// Deduplicate currentToolResults
|
||||
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||
|
||||
// Build userInputMessageContext with tools and tool results
|
||||
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
|
||||
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
Tools: kiroTools,
|
||||
ToolResults: currentToolResults,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build payload
|
||||
var currentMessage KiroCurrentMessage
|
||||
if currentUserMsg != nil {
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
|
||||
} else {
|
||||
fallbackContent := ""
|
||||
if systemPrompt != "" {
|
||||
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
|
||||
}
|
||||
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
|
||||
Content: fallbackContent,
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}}
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
var inferenceConfig *KiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||
inferenceConfig = &KiroInferenceConfig{}
|
||||
if maxTokens > 0 {
|
||||
inferenceConfig.MaxTokens = int(maxTokens)
|
||||
}
|
||||
if hasTemperature {
|
||||
inferenceConfig.Temperature = temperature
|
||||
}
|
||||
if hasTopP {
|
||||
inferenceConfig.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
payload := KiroPayload{
|
||||
ConversationState: KiroConversationState{
|
||||
ChatTriggerType: "MANUAL",
|
||||
ConversationID: uuid.New().String(),
|
||||
CurrentMessage: currentMessage,
|
||||
History: history,
|
||||
},
|
||||
ProfileArn: profileArn,
|
||||
InferenceConfig: inferenceConfig,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Debugf("kiro-openai: failed to marshal payload: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return result, thinkingEnabled
|
||||
}
|
||||
|
||||
// normalizeOrigin maps the known origin aliases onto the two values used for
// the Kiro API: "KIRO_CLI" and "AMAZON_Q" become "CLI", "KIRO_AI_EDITOR" and
// "KIRO_IDE" become "AI_EDITOR". Any other value is returned unchanged.
func normalizeOrigin(origin string) string {
	switch origin {
	case "KIRO_CLI", "AMAZON_Q":
		return "CLI"
	case "KIRO_AI_EDITOR", "KIRO_IDE":
		return "AI_EDITOR"
	}
	return origin
}
|
||||
|
||||
// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages
|
||||
func extractSystemPromptFromOpenAI(messages gjson.Result) string {
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
|
||||
var systemParts []string
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "system" {
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
systemParts = append(systemParts, content.String())
|
||||
} else if content.IsArray() {
|
||||
// Handle array content format
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
systemParts = append(systemParts, part.Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(systemParts, "\n")
|
||||
}
|
||||
|
||||
// shortenToolNameIfNeeded returns name unchanged when it is at most 64 bytes
// long, and a shortened form otherwise.
//
// Names starting with "mcp__" (e.g. "mcp__server-name__tool-name") are first
// reduced to "mcp__" plus the text after the last "__"; whatever remains is
// then cut to 64 bytes if it is still too long. Any other over-long name is
// simply cut to its first 64 bytes.
func shortenToolNameIfNeeded(name string) string {
	const maxLen = 64
	if len(name) <= maxLen {
		return name
	}

	shortened := name
	if strings.HasPrefix(name, "mcp__") {
		// Keep the prefix and the last "__"-separated segment.
		if sep := strings.LastIndex(name, "__"); sep > 0 {
			shortened = "mcp__" + name[sep+2:]
		}
	}

	if len(shortened) > maxLen {
		shortened = shortened[:maxLen]
	}
	return shortened
}
|
||||
|
||||
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
||||
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
if !tools.IsArray() {
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
// OpenAI tools have type "function" with function definition inside
|
||||
if tool.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
fn := tool.Get("function")
|
||||
if !fn.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := fn.Get("name").String()
|
||||
description := fn.Get("description").String()
|
||||
parameters := fn.Get("parameters").Value()
|
||||
|
||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||
originalName := name
|
||||
name = shortenToolNameIfNeeded(name)
|
||||
if name != originalName {
|
||||
log.Debugf("kiro-openai: shortened tool name from '%s' to '%s'", originalName, name)
|
||||
}
|
||||
|
||||
// CRITICAL FIX: Kiro API requires non-empty description
|
||||
if strings.TrimSpace(description) == "" {
|
||||
description = fmt.Sprintf("Tool: %s", name)
|
||||
log.Debugf("kiro-openai: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Truncate long descriptions
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
truncLen--
|
||||
}
|
||||
description = description[:truncLen] + "... (description truncated)"
|
||||
}
|
||||
|
||||
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: KiroInputSchema{JSON: parameters},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// processOpenAIMessages processes OpenAI messages and builds Kiro history
|
||||
func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
|
||||
var history []KiroHistoryMessage
|
||||
var currentUserMsg *KiroUserInputMessage
|
||||
var currentToolResults []KiroToolResult
|
||||
|
||||
if !messages.IsArray() {
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
// Merge adjacent messages with the same role
|
||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||
|
||||
// Build tool_call_id to name mapping from assistant messages
|
||||
toolCallIDToName := make(map[string]string)
|
||||
for _, msg := range messagesArray {
|
||||
if msg.Get("role").String() == "assistant" {
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if toolCalls.IsArray() {
|
||||
for _, tc := range toolCalls.Array() {
|
||||
if tc.Get("type").String() == "function" {
|
||||
id := tc.Get("id").String()
|
||||
name := tc.Get("function.name").String()
|
||||
if id != "" && name != "" {
|
||||
toolCallIDToName[id] = name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, msg := range messagesArray {
|
||||
role := msg.Get("role").String()
|
||||
isLastMessage := i == len(messagesArray)-1
|
||||
|
||||
switch role {
|
||||
case "system":
|
||||
// System messages are handled separately via extractSystemPromptFromOpenAI
|
||||
continue
|
||||
|
||||
case "user":
|
||||
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
||||
if isLastMessage {
|
||||
currentUserMsg = &userMsg
|
||||
currentToolResults = toolResults
|
||||
} else {
|
||||
// CRITICAL: Kiro API requires content to be non-empty for history messages
|
||||
if strings.TrimSpace(userMsg.Content) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.Content = "Tool results provided."
|
||||
} else {
|
||||
userMsg.Content = "Continue"
|
||||
}
|
||||
}
|
||||
// For history messages, embed tool results in context
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
}
|
||||
history = append(history, KiroHistoryMessage{
|
||||
UserInputMessage: &userMsg,
|
||||
})
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
||||
if isLastMessage {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
// Create a "Continue" user message as currentMessage
|
||||
currentUserMsg = &KiroUserInputMessage{
|
||||
Content: "Continue",
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
} else {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
AssistantResponseMessage: &assistantMsg,
|
||||
})
|
||||
}
|
||||
|
||||
case "tool":
|
||||
// Tool messages in OpenAI format provide results for tool_calls
|
||||
// These are typically followed by user or assistant messages
|
||||
// Process them and merge into the next user message's tool results
|
||||
toolCallID := msg.Get("tool_call_id").String()
|
||||
content := msg.Get("content").String()
|
||||
|
||||
if toolCallID != "" {
|
||||
toolResult := KiroToolResult{
|
||||
ToolUseID: toolCallID,
|
||||
Content: []KiroTextContent{{Text: content}},
|
||||
Status: "success",
|
||||
}
|
||||
// Tool results should be included in the next user message
|
||||
// For now, collect them and they'll be handled when we build the current message
|
||||
currentToolResults = append(currentToolResults, toolResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
||||
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolResults []KiroToolResult
|
||||
var images []KiroImage
|
||||
|
||||
// Track seen toolCallIds to deduplicate
|
||||
seenToolCallIDs := make(map[string]bool)
|
||||
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "image_url":
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Parse data URL: data:image/png;base64,xxxxx
|
||||
if idx := strings.Index(imageURL, ";base64,"); idx != -1 {
|
||||
mediaType := imageURL[5:idx] // Skip "data:"
|
||||
data := imageURL[idx+8:] // Skip ";base64,"
|
||||
|
||||
format := ""
|
||||
if lastSlash := strings.LastIndex(mediaType, "/"); lastSlash != -1 {
|
||||
format = mediaType[lastSlash+1:]
|
||||
}
|
||||
|
||||
if format != "" && data != "" {
|
||||
images = append(images, KiroImage{
|
||||
Format: format,
|
||||
Source: KiroImageSource{
|
||||
Bytes: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if content.Type == gjson.String {
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
|
||||
_ = seenToolCallIDs // Used for deduplication if needed
|
||||
|
||||
userMsg := KiroUserInputMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
}
|
||||
|
||||
if len(images) > 0 {
|
||||
userMsg.Images = images
|
||||
}
|
||||
|
||||
return userMsg, toolResults
|
||||
}
|
||||
|
||||
// buildAssistantMessageFromOpenAI builds an assistant message from OpenAI format
|
||||
func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMessage {
|
||||
content := msg.Get("content")
|
||||
var contentBuilder strings.Builder
|
||||
var toolUses []KiroToolUse
|
||||
|
||||
// Handle content
|
||||
if content.Type == gjson.String {
|
||||
contentBuilder.WriteString(content.String())
|
||||
} else if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool_calls
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if toolCalls.IsArray() {
|
||||
for _, tc := range toolCalls.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolUseID := tc.Get("id").String()
|
||||
toolName := tc.Get("function.name").String()
|
||||
toolArgs := tc.Get("function.arguments").String()
|
||||
|
||||
var inputMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolArgs), &inputMap); err != nil {
|
||||
log.Debugf("kiro-openai: failed to parse tool arguments: %v", err)
|
||||
inputMap = make(map[string]interface{})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return KiroAssistantResponseMessage{
|
||||
Content: contentBuilder.String(),
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
|
||||
// buildFinalContent builds the final content with system prompt
|
||||
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
if systemPrompt != "" {
|
||||
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
|
||||
contentBuilder.WriteString(systemPrompt)
|
||||
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
|
||||
}
|
||||
|
||||
contentBuilder.WriteString(content)
|
||||
finalContent := contentBuilder.String()
|
||||
|
||||
// CRITICAL: Kiro API requires content to be non-empty
|
||||
if strings.TrimSpace(finalContent) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
finalContent = "Tool results provided."
|
||||
} else {
|
||||
finalContent = "Continue"
|
||||
}
|
||||
log.Debugf("kiro-openai: content was empty, using default: %s", finalContent)
|
||||
}
|
||||
|
||||
return finalContent
|
||||
}
|
||||
|
||||
// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request.
|
||||
// Returns (thinkingEnabled, budgetTokens).
|
||||
// Supports:
|
||||
// - reasoning_effort parameter (low/medium/high/auto)
|
||||
// - Model name containing "thinking" or "reason"
|
||||
// - <thinking_mode> tag in system prompt (AMP/Cursor format)
|
||||
func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) {
|
||||
var budgetTokens int64 = 16000 // Default budget
|
||||
|
||||
// Check OpenAI format: reasoning_effort parameter
|
||||
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||
reasoningEffort := gjson.GetBytes(openaiBody, "reasoning_effort")
|
||||
if reasoningEffort.Exists() {
|
||||
effort := reasoningEffort.String()
|
||||
if effort != "" && effort != "none" {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort)
|
||||
// Adjust budget based on effort level
|
||||
switch effort {
|
||||
case "low":
|
||||
budgetTokens = 8000
|
||||
case "medium":
|
||||
budgetTokens = 16000
|
||||
case "high":
|
||||
budgetTokens = 32000
|
||||
case "auto":
|
||||
budgetTokens = 16000
|
||||
}
|
||||
return true, budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
bodyStr := string(openaiBody)
|
||||
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
|
||||
startTag := "<thinking_mode>"
|
||||
endTag := "</thinking_mode>"
|
||||
startIdx := strings.Index(bodyStr, startTag)
|
||||
if startIdx >= 0 {
|
||||
startIdx += len(startTag)
|
||||
endIdx := strings.Index(bodyStr[startIdx:], endTag)
|
||||
if endIdx >= 0 {
|
||||
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||
// Try to extract max_thinking_length if present
|
||||
if maxLenStart := strings.Index(bodyStr, "<max_thinking_length>"); maxLenStart >= 0 {
|
||||
maxLenStart += len("<max_thinking_length>")
|
||||
if maxLenEnd := strings.Index(bodyStr[maxLenStart:], "</max_thinking_length>"); maxLenEnd >= 0 {
|
||||
maxLenStr := bodyStr[maxLenStart : maxLenStart+maxLenEnd]
|
||||
if parsed, err := fmt.Sscanf(maxLenStr, "%d", &budgetTokens); err == nil && parsed == 1 {
|
||||
log.Debugf("kiro-openai: extracted max_thinking_length: %d", budgetTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, budgetTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check model name for thinking hints
|
||||
model := gjson.GetBytes(openaiBody, "model").String()
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model)
|
||||
return true, budgetTokens
|
||||
}
|
||||
|
||||
log.Debugf("kiro-openai: no thinking mode detected in OpenAI request")
|
||||
return false, budgetTokens
|
||||
}
|
||||
|
||||
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
||||
// OpenAI tool_choice values:
|
||||
// - "none": Don't use any tools
|
||||
// - "auto": Model decides (default, no hint needed)
|
||||
// - "required": Must use at least one tool
|
||||
// - {"type":"function","function":{"name":"..."}} : Must use specific tool
|
||||
func extractToolChoiceHint(openaiBody []byte) string {
|
||||
toolChoice := gjson.GetBytes(openaiBody, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle string values
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "none":
|
||||
// Note: When tool_choice is "none", we should ideally not pass tools at all
|
||||
// But since we can't modify tool passing here, we add a strong hint
|
||||
return "[INSTRUCTION: Do NOT use any tools. Respond with text only.]"
|
||||
case "required":
|
||||
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
|
||||
case "auto":
|
||||
// Default behavior, no hint needed
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Handle object value: {"type":"function","function":{"name":"..."}}
|
||||
if toolChoice.IsObject() {
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
toolName := toolChoice.Get("function.name").String()
|
||||
if toolName != "" {
|
||||
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractResponseFormatHint extracts response_format from OpenAI request and returns a system prompt hint.
|
||||
// OpenAI response_format values:
|
||||
// - {"type": "text"}: Default, no hint needed
|
||||
// - {"type": "json_object"}: Must respond with valid JSON
|
||||
// - {"type": "json_schema", "json_schema": {...}}: Must respond with JSON matching schema
|
||||
func extractResponseFormatHint(openaiBody []byte) string {
|
||||
responseFormat := gjson.GetBytes(openaiBody, "response_format")
|
||||
if !responseFormat.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
formatType := responseFormat.Get("type").String()
|
||||
switch formatType {
|
||||
case "json_object":
|
||||
return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
|
||||
case "json_schema":
|
||||
// Extract schema if provided
|
||||
schema := responseFormat.Get("json_schema.schema")
|
||||
if schema.Exists() {
|
||||
schemaStr := schema.Raw
|
||||
// Truncate if too long
|
||||
if len(schemaStr) > 500 {
|
||||
schemaStr = schemaStr[:500] + "..."
|
||||
}
|
||||
return fmt.Sprintf("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: %s. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]", schemaStr)
|
||||
}
|
||||
return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
|
||||
case "text":
|
||||
// Default behavior, no hint needed
|
||||
return ""
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// deduplicateToolResults removes duplicate tool results
|
||||
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||
if len(toolResults) == 0 {
|
||||
return toolResults
|
||||
}
|
||||
|
||||
seenIDs := make(map[string]bool)
|
||||
unique := make([]KiroToolResult, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if !seenIDs[tr.ToolUseID] {
|
||||
seenIDs[tr.ToolUseID] = true
|
||||
unique = append(unique, tr)
|
||||
} else {
|
||||
log.Debugf("kiro-openai: skipping duplicate toolResult: %s", tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
return unique
|
||||
}
|
||||
264
internal/translator/kiro/openai/kiro_openai_response.go
Normal file
264
internal/translator/kiro/openai/kiro_openai_response.go
Normal file
@@ -0,0 +1,264 @@
|
||||
// Package openai provides response translation from Kiro to OpenAI format.
|
||||
// This package handles the conversion of Kiro API responses into OpenAI Chat Completions-compatible
|
||||
// JSON format, transforming streaming events and non-streaming responses.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// functionCallIDCounter is a package-level counter intended to give function
// call identifiers a process-wide unique component.
// NOTE(review): its users are outside this view; presumably it is advanced
// via sync/atomic, which this file imports — confirm.
var functionCallIDCounter uint64
|
||||
|
||||
// BuildOpenAIResponse constructs an OpenAI Chat Completions-compatible response.
|
||||
// Supports tool_calls when tools are present in the response.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
// Build the message object
|
||||
message := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
// Add tool_calls if present
|
||||
if len(toolUses) > 0 {
|
||||
var toolCalls []map[string]interface{}
|
||||
for i, tu := range toolUses {
|
||||
inputJSON, _ := json.Marshal(tu.Input)
|
||||
toolCalls = append(toolCalls, map[string]interface{}{
|
||||
"id": tu.ToolUseID,
|
||||
"type": "function",
|
||||
"index": i,
|
||||
"function": map[string]interface{}{
|
||||
"name": tu.Name,
|
||||
"arguments": string(inputJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
message["tool_calls"] = toolCalls
|
||||
// When tool_calls are present, content should be null according to OpenAI spec
|
||||
if content == "" {
|
||||
message["content"] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Use upstream stopReason; apply fallback logic if not provided
|
||||
finishReason := mapKiroStopReasonToOpenAI(stopReason)
|
||||
if finishReason == "" {
|
||||
finishReason = "stop"
|
||||
if len(toolUses) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
log.Debugf("kiro-openai: buildOpenAIResponse using fallback finish_reason: %s", finishReason)
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": usageInfo.InputTokens,
|
||||
"completion_tokens": usageInfo.OutputTokens,
|
||||
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return result
|
||||
}
|
||||
|
||||
// mapKiroStopReasonToOpenAI translates a Kiro/Claude stop_reason into the
// corresponding OpenAI finish_reason:
//
//	end_turn, stop_sequence -> stop
//	tool_use                -> tool_calls
//	max_tokens              -> length
//	content_filtered        -> content_filter
//
// Any other value, including "", is returned unchanged.
func mapKiroStopReasonToOpenAI(stopReason string) string {
	switch stopReason {
	case "end_turn", "stop_sequence":
		return "stop"
	case "tool_use":
		return "tool_calls"
	case "max_tokens":
		return "length"
	case "content_filtered":
		return "content_filter"
	}
	return stopReason
}
|
||||
|
||||
// BuildOpenAIStreamChunk constructs an OpenAI Chat Completions streaming chunk.
|
||||
// This is the delta format used in streaming responses.
|
||||
func BuildOpenAIStreamChunk(model string, deltaContent string, deltaToolCalls []map[string]interface{}, finishReason string, index int) []byte {
|
||||
delta := map[string]interface{}{}
|
||||
|
||||
// First chunk should include role
|
||||
if index == 0 && deltaContent == "" && len(deltaToolCalls) == 0 {
|
||||
delta["role"] = "assistant"
|
||||
delta["content"] = ""
|
||||
} else if deltaContent != "" {
|
||||
delta["content"] = deltaContent
|
||||
}
|
||||
|
||||
// Add tool_calls delta if present
|
||||
if len(deltaToolCalls) > 0 {
|
||||
delta["tool_calls"] = deltaToolCalls
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
if finishReason != "" {
|
||||
choice["finish_reason"] = finishReason
|
||||
} else {
|
||||
choice["finish_reason"] = nil
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamChunkWithToolCallStart creates a stream chunk for tool call start
|
||||
func BuildOpenAIStreamChunkWithToolCallStart(model string, toolUseID, toolName string, toolIndex int) []byte {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": toolIndex,
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": nil,
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamChunkWithToolCallDelta creates a stream chunk for tool call arguments delta
|
||||
func BuildOpenAIStreamChunkWithToolCallDelta(model string, argumentsDelta string, toolIndex int) []byte {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": toolIndex,
|
||||
"function": map[string]interface{}{
|
||||
"arguments": argumentsDelta,
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": nil,
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamDoneChunk returns the terminating stream event as bytes.
// Unlike the chunk builders in this file, the returned value already
// includes the SSE "data: " prefix.
// NOTE(review): confirm callers do not prepend "data:" a second time.
func BuildOpenAIStreamDoneChunk() []byte {
	const doneEvent = "data: [DONE]"
	return []byte(doneEvent)
}
|
||||
|
||||
// BuildOpenAIStreamFinishChunk creates the final chunk with finish_reason
|
||||
func BuildOpenAIStreamFinishChunk(model string, finishReason string) []byte {
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": finishReason,
|
||||
}
|
||||
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildOpenAIStreamUsageChunk creates a chunk with usage information (optional, for stream_options.include_usage)
|
||||
func BuildOpenAIStreamUsageChunk(model string, usageInfo usage.Detail) []byte {
|
||||
chunk := map[string]interface{}{
|
||||
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]interface{}{},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": usageInfo.InputTokens,
|
||||
"completion_tokens": usageInfo.OutputTokens,
|
||||
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(chunk)
|
||||
return result
|
||||
}
|
||||
|
||||
// GenerateToolCallID generates a unique tool call ID in OpenAI format
|
||||
func GenerateToolCallID(toolName string) string {
|
||||
return fmt.Sprintf("call_%s_%d_%d", toolName[:min(8, len(toolName))], time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))
|
||||
}
|
||||
|
||||
// min returns the smaller of a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
||||
212
internal/translator/kiro/openai/kiro_openai_stream.go
Normal file
212
internal/translator/kiro/openai/kiro_openai_stream.go
Normal file
@@ -0,0 +1,212 @@
|
||||
// Package openai provides streaming SSE event building for OpenAI format.
|
||||
// This package handles the construction of OpenAI-compatible Server-Sent Events (SSE)
|
||||
// for streaming responses from Kiro API.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
// OpenAIStreamState carries the per-response state shared by the SSE
// builders in this file while a stream is being converted.
type OpenAIStreamState struct {
	// ChunkIndex is incremented by each delta/finish builder after it has
	// produced a chunk.
	ChunkIndex int
	// ToolCallIndex is the index written into a tool call start event.
	ToolCallIndex int
	// HasSentFirstChunk records that the assistant role has already been
	// emitted, so later deltas omit it.
	HasSentFirstChunk bool
	// Model is copied into the "model" field of every chunk.
	Model string
	// ResponseID is the "id" shared by all chunks of the response.
	ResponseID string
	// Created is the "created" timestamp shared by all chunks (Unix seconds).
	Created int64
}
|
||||
|
||||
// NewOpenAIStreamState creates a new stream state for tracking
|
||||
func NewOpenAIStreamState(model string) *OpenAIStreamState {
|
||||
return &OpenAIStreamState{
|
||||
ChunkIndex: 0,
|
||||
ToolCallIndex: 0,
|
||||
HasSentFirstChunk: false,
|
||||
Model: model,
|
||||
ResponseID: "chatcmpl-" + uuid.New().String()[:24],
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSSEEvent converts a JSON payload into the string handed to the
// streaming layer.
// Note: the payload is returned as-is, without a "data:" prefix.
// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
// to maintain architectural consistency and avoid double-prefix issues.
func FormatSSEEvent(data []byte) string {
	payload := string(data)
	return payload
}
|
||||
|
||||
// BuildOpenAISSETextDelta creates an SSE event for text content delta
|
||||
func BuildOpenAISSETextDelta(state *OpenAIStreamState, textDelta string) string {
|
||||
delta := map[string]interface{}{
|
||||
"content": textDelta,
|
||||
}
|
||||
|
||||
// Include role in first chunk
|
||||
if !state.HasSentFirstChunk {
|
||||
delta["role"] = "assistant"
|
||||
state.HasSentFirstChunk = true
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEToolCallStart creates an SSE event for tool call start
|
||||
func BuildOpenAISSEToolCallStart(state *OpenAIStreamState, toolUseID, toolName string) string {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": state.ToolCallIndex,
|
||||
"id": toolUseID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
// Include role in first chunk if not sent yet
|
||||
if !state.HasSentFirstChunk {
|
||||
delta["role"] = "assistant"
|
||||
state.HasSentFirstChunk = true
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEToolCallArgumentsDelta creates an SSE event for tool call arguments delta
|
||||
func BuildOpenAISSEToolCallArgumentsDelta(state *OpenAIStreamState, argumentsDelta string, toolIndex int) string {
|
||||
toolCall := map[string]interface{}{
|
||||
"index": toolIndex,
|
||||
"function": map[string]interface{}{
|
||||
"arguments": argumentsDelta,
|
||||
},
|
||||
}
|
||||
|
||||
delta := map[string]interface{}{
|
||||
"tool_calls": []map[string]interface{}{toolCall},
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEFinish creates an SSE event with finish_reason
|
||||
func BuildOpenAISSEFinish(state *OpenAIStreamState, finishReason string) string {
|
||||
chunk := buildBaseChunk(state, map[string]interface{}{}, &finishReason)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEUsage creates an SSE event with usage information
|
||||
func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) string {
|
||||
chunk := map[string]interface{}{
|
||||
"id": state.ResponseID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": state.Created,
|
||||
"model": state.Model,
|
||||
"choices": []map[string]interface{}{},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": usageInfo.InputTokens,
|
||||
"completion_tokens": usageInfo.OutputTokens,
|
||||
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(chunk)
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEDone returns the marker that terminates the stream.
// Note: the value is the raw "[DONE]" without a "data:" prefix.
// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
// to maintain architectural consistency and avoid double-prefix issues.
func BuildOpenAISSEDone() string {
	const doneMarker = "[DONE]"
	return doneMarker
}
|
||||
|
||||
// buildBaseChunk creates a base chunk structure for streaming
|
||||
func buildBaseChunk(state *OpenAIStreamState, delta map[string]interface{}, finishReason *string) map[string]interface{} {
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
if finishReason != nil {
|
||||
choice["finish_reason"] = *finishReason
|
||||
} else {
|
||||
choice["finish_reason"] = nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"id": state.ResponseID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": state.Created,
|
||||
"model": state.Model,
|
||||
"choices": []map[string]interface{}{choice},
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAISSEReasoningDelta creates an SSE event for reasoning content delta
|
||||
// This is used for o1/o3 style models that expose reasoning tokens
|
||||
func BuildOpenAISSEReasoningDelta(state *OpenAIStreamState, reasoningDelta string) string {
|
||||
delta := map[string]interface{}{
|
||||
"reasoning_content": reasoningDelta,
|
||||
}
|
||||
|
||||
// Include role in first chunk
|
||||
if !state.HasSentFirstChunk {
|
||||
delta["role"] = "assistant"
|
||||
state.HasSentFirstChunk = true
|
||||
}
|
||||
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// BuildOpenAISSEFirstChunk creates the first chunk with role only
|
||||
func BuildOpenAISSEFirstChunk(state *OpenAIStreamState) string {
|
||||
delta := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
state.HasSentFirstChunk = true
|
||||
chunk := buildBaseChunk(state, delta, nil)
|
||||
result, _ := json.Marshal(chunk)
|
||||
state.ChunkIndex++
|
||||
return FormatSSEEvent(result)
|
||||
}
|
||||
|
||||
// ThinkingTagState holds the state used to detect thinking tags while a
// response is streamed.
// NOTE(review): no code in this file reads or writes these fields; the
// per-field notes below are inferred from the names and should be confirmed
// against the caller.
type ThinkingTagState struct {
	// InThinkingBlock: presumably true while inside a thinking block.
	InThinkingBlock bool
	// PendingStartChars: presumably a count of buffered characters that may
	// begin an opening tag.
	PendingStartChars int
	// PendingEndChars: presumably a count of buffered characters that may
	// begin a closing tag.
	PendingEndChars int
}
|
||||
|
||||
// NewThinkingTagState creates a new thinking tag state
|
||||
func NewThinkingTagState() *ThinkingTagState {
|
||||
return &ThinkingTagState{
|
||||
InThinkingBlock: false,
|
||||
PendingStartChars: 0,
|
||||
PendingEndChars: 0,
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -60,6 +61,18 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
if thinking := root.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() && thinkingType.String() == "enabled" {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages and system
|
||||
var messagesJSON = "[]"
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -76,6 +77,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "stop", stops)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert thinkingBudget to reasoning_effort
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream parameter
|
||||
|
||||
@@ -2,6 +2,7 @@ package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -189,23 +190,9 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
if reasoningEffort := root.Get("reasoning.effort"); reasoningEffort.Exists() {
|
||||
switch reasoningEffort.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "none")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "low")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "low")
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "medium")
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "high")
|
||||
case "xhigh":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "xhigh")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
effort := strings.ToLower(strings.TrimSpace(reasoningEffort.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,9 @@ func ApplyClaudeThinkingConfig(body []byte, budget *int) []byte {
|
||||
// It uses the unified ResolveThinkingConfigFromMetadata and normalizes the budget.
|
||||
// Returns the normalized budget (nil if thinking should not be enabled) and whether it matched.
|
||||
func ResolveClaudeThinkingConfig(modelName string, metadata map[string]any) (*int, bool) {
|
||||
if !ModelSupportsThinking(modelName) {
|
||||
return nil, false
|
||||
}
|
||||
budget, include, matched := ResolveThinkingConfigFromMetadata(modelName, metadata)
|
||||
if !matched {
|
||||
return nil, false
|
||||
|
||||
@@ -25,9 +25,15 @@ func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool)
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if includeThoughts != nil {
|
||||
// Default to including thoughts when a budget override is present but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && budget != nil && *budget != 0 {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts)
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
@@ -47,9 +53,15 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if includeThoughts != nil {
|
||||
// Default to including thoughts when a budget override is present but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && budget != nil && *budget != 0 {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts)
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
@@ -140,6 +152,71 @@ func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
return updated
|
||||
}
|
||||
|
||||
// ReasoningEffortBudgetMapping gives the thinkingBudget value used for each
// reasoning effort level. Keys are lower-case effort names.
var ReasoningEffortBudgetMapping = map[string]int{
	// Special values.
	"none": 0,
	"auto": -1,

	// Explicit budgets, smallest to largest.
	"minimal": 512,
	"low":     1024,
	"medium":  8192,
	"high":    24576,
	"xhigh":   32768,
}
|
||||
|
||||
// ApplyReasoningEffortToGemini applies OpenAI reasoning_effort to Gemini thinkingConfig
|
||||
// for standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Returns the modified body with thinkingBudget and include_thoughts set.
|
||||
func ApplyReasoningEffortToGemini(body []byte, effort string) []byte {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
|
||||
includePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
|
||||
if normalized == "none" {
|
||||
body, _ = sjson.DeleteBytes(body, "generationConfig.thinkingConfig")
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ReasoningEffortBudgetMapping[normalized]
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, budgetPath, budget)
|
||||
body, _ = sjson.SetBytes(body, includePath, true)
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyReasoningEffortToGeminiCLI applies OpenAI reasoning_effort to Gemini CLI thinkingConfig
|
||||
// for Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body with thinkingBudget and include_thoughts set.
|
||||
func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
includePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
|
||||
if normalized == "none" {
|
||||
body, _ = sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig")
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ReasoningEffortBudgetMapping[normalized]
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, budgetPath, budget)
|
||||
body, _ = sjson.SetBytes(body, includePath, true)
|
||||
return body
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget".
|
||||
// "high" -> 32768
|
||||
|
||||
37
internal/util/openai_thinking.go
Normal file
37
internal/util/openai_thinking.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package util
|
||||
|
||||
// OpenAIThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// into an OpenAI-style reasoning effort level for level-based models.
|
||||
//
|
||||
// Ranges:
|
||||
// - 0 -> "none"
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Negative values other than -1 are treated as unsupported.
|
||||
func OpenAIThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -25,33 +25,33 @@ func ModelSupportsThinking(model string) bool {
|
||||
// or min (0 if zero is allowed and mid <= 0).
|
||||
func NormalizeThinkingBudget(model string, budget int) int {
|
||||
if budget == -1 { // dynamic
|
||||
if found, min, max, zeroAllowed, dynamicAllowed := thinkingRangeFromRegistry(model); found {
|
||||
if found, minBudget, maxBudget, zeroAllowed, dynamicAllowed := thinkingRangeFromRegistry(model); found {
|
||||
if dynamicAllowed {
|
||||
return -1
|
||||
}
|
||||
mid := (min + max) / 2
|
||||
mid := (minBudget + maxBudget) / 2
|
||||
if mid <= 0 && zeroAllowed {
|
||||
return 0
|
||||
}
|
||||
if mid <= 0 {
|
||||
return min
|
||||
return minBudget
|
||||
}
|
||||
return mid
|
||||
}
|
||||
return -1
|
||||
}
|
||||
if found, min, max, zeroAllowed, _ := thinkingRangeFromRegistry(model); found {
|
||||
if found, minBudget, maxBudget, zeroAllowed, _ := thinkingRangeFromRegistry(model); found {
|
||||
if budget == 0 {
|
||||
if zeroAllowed {
|
||||
return 0
|
||||
}
|
||||
return min
|
||||
return minBudget
|
||||
}
|
||||
if budget < min {
|
||||
return min
|
||||
if budget < minBudget {
|
||||
return minBudget
|
||||
}
|
||||
if budget > max {
|
||||
return max
|
||||
if budget > maxBudget {
|
||||
return maxBudget
|
||||
}
|
||||
return budget
|
||||
}
|
||||
@@ -105,3 +105,16 @@ func NormalizeReasoningEffortLevel(model, effort string) (string, bool) {
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsOpenAICompatibilityModel reports whether the model is registered as an OpenAI-compatibility model.
|
||||
// These models may not advertise Thinking metadata in the registry.
|
||||
func IsOpenAICompatibilityModel(model string) bool {
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(model)
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(info.Type), "openai-compatibility")
|
||||
}
|
||||
|
||||
@@ -163,6 +163,11 @@ func ResolveThinkingConfigFromMetadata(model string, metadata map[string]any) (*
|
||||
if !matched {
|
||||
return nil, nil, false
|
||||
}
|
||||
// Level-based models (OpenAI-style) do not accept numeric thinking budgets in
|
||||
// Claude/Gemini-style protocols, so we don't derive budgets for them here.
|
||||
if ModelUsesThinkingLevels(model) {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
if budget == nil && effort != nil {
|
||||
if derived, ok := ThinkingEffortToBudget(model, *effort); ok {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -62,6 +63,7 @@ type Watcher struct {
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastRemoveTimes map[string]time.Time
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
@@ -128,8 +130,9 @@ type AuthUpdate struct {
|
||||
const (
|
||||
// replaceCheckDelay is a short delay to allow atomic replace (rename) to settle
|
||||
// before deciding whether a Remove event indicates a real deletion.
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
configReloadDebounce = 150 * time.Millisecond
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
configReloadDebounce = 150 * time.Millisecond
|
||||
authRemoveDebounceWindow = 1 * time.Second
|
||||
)
|
||||
|
||||
// NewWatcher creates a new file watcher instance
|
||||
@@ -750,8 +753,9 @@ func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.RLock()
|
||||
prevHash, ok := w.lastAuthHashes[path]
|
||||
prevHash, ok := w.lastAuthHashes[normalized]
|
||||
w.clientsMutex.RUnlock()
|
||||
if ok && prevHash == curHash {
|
||||
return true, nil
|
||||
@@ -760,19 +764,63 @@ func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||
}
|
||||
|
||||
func (w *Watcher) isKnownAuthFile(path string) bool {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
_, ok := w.lastAuthHashes[path]
|
||||
_, ok := w.lastAuthHashes[normalized]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (w *Watcher) normalizeAuthPath(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
cleaned := filepath.Clean(trimmed)
|
||||
if runtime.GOOS == "windows" {
|
||||
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
|
||||
cleaned = strings.ToLower(cleaned)
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
|
||||
if normalizedPath == "" {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.lastRemoveTimes == nil {
|
||||
w.lastRemoveTimes = make(map[string]time.Time)
|
||||
}
|
||||
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
|
||||
if now.Sub(last) < authRemoveDebounceWindow {
|
||||
w.clientsMutex.Unlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
w.lastRemoveTimes[normalizedPath] = now
|
||||
if len(w.lastRemoveTimes) > 128 {
|
||||
cutoff := now.Add(-2 * authRemoveDebounceWindow)
|
||||
for p, t := range w.lastRemoveTimes {
|
||||
if t.Before(cutoff) {
|
||||
delete(w.lastRemoveTimes, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
// handleEvent processes individual file system events
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Filter only relevant events: config file or auth-dir JSON files.
|
||||
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
||||
isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0
|
||||
normalizedName := w.normalizeAuthPath(event.Name)
|
||||
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
|
||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0
|
||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
|
||||
// Check for Kiro IDE token file changes
|
||||
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
||||
@@ -800,6 +848,10 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
|
||||
// Handle auth directory changes incrementally (.json only)
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
if w.shouldDebounceRemove(normalizedName, now) {
|
||||
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||
time.Sleep(replaceCheckDelay)
|
||||
@@ -1062,7 +1114,8 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
|
||||
sum := sha256.Sum256(data)
|
||||
w.lastAuthHashes[path] = hex.EncodeToString(sum[:])
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1109,6 +1162,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
@@ -1118,14 +1172,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if prev, ok := w.lastAuthHashes[path]; ok && prev == curHash {
|
||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Update hash cache
|
||||
w.lastAuthHashes[path] = curHash
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
@@ -1140,10 +1194,11 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
|
||||
// removeClient handles the removal of a single client.
|
||||
func (w *Watcher) removeClient(path string) {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
delete(w.lastAuthHashes, path)
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
@@ -1317,6 +1372,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
if kk.AgentTaskType != "" {
|
||||
attrs["agent_task_type"] = kk.AgentTaskType
|
||||
}
|
||||
if kk.PreferredEndpoint != "" {
|
||||
attrs["preferred_endpoint"] = kk.PreferredEndpoint
|
||||
} else if cfg.KiroPreferredEndpoint != "" {
|
||||
// Apply global default if not overridden by specific key
|
||||
attrs["preferred_endpoint"] = cfg.KiroPreferredEndpoint
|
||||
}
|
||||
if refreshToken != "" {
|
||||
attrs["refresh_token"] = refreshToken
|
||||
}
|
||||
@@ -1532,6 +1593,17 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply global preferred endpoint setting if not present in metadata
|
||||
if cfg.KiroPreferredEndpoint != "" {
|
||||
// Check if already set in metadata (which takes precedence in executor)
|
||||
if _, hasMeta := metadata["preferred_endpoint"]; !hasMeta {
|
||||
if a.Attributes == nil {
|
||||
a.Attributes = make(map[string]string)
|
||||
}
|
||||
a.Attributes["preferred_endpoint"] = cfg.KiroPreferredEndpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
@@ -219,52 +218,24 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
// v6.1: Intelligent Buffered Streamer strategy
|
||||
// Enhanced buffering with larger buffer size (16KB) and longer flush interval (120ms).
|
||||
// Smart flush only when buffer is sufficiently filled (≥50%), dramatically reducing
|
||||
// flush frequency from ~12.5Hz to ~5-8Hz while maintaining low latency.
|
||||
writer := bufio.NewWriterSize(c.Writer, 16*1024) // 4KB → 16KB
|
||||
ticker := time.NewTicker(120 * time.Millisecond) // 80ms → 120ms
|
||||
defer ticker.Stop()
|
||||
|
||||
var chunkIdx int
|
||||
|
||||
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
|
||||
// This guarantees clients see incremental output even for small responses.
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
// Context cancelled, flush any remaining data before exit
|
||||
_ = writer.Flush()
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
// Smart flush: only flush when buffer has sufficient data (≥50% full)
|
||||
// This reduces flush frequency while ensuring data flows naturally
|
||||
buffered := writer.Buffered()
|
||||
if buffered >= 8*1024 { // At least 8KB (50% of 16KB buffer)
|
||||
if err := writer.Flush(); err != nil {
|
||||
// Error flushing, cancel and return
|
||||
cancel(err)
|
||||
return
|
||||
}
|
||||
flusher.Flush() // Also flush the underlying http.ResponseWriter
|
||||
}
|
||||
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
// Stream ended, flush remaining data
|
||||
_ = writer.Flush()
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward the complete SSE event block directly (already formatted by the translator).
|
||||
// The translator returns a complete SSE-compliant event block, including event:, data:, and separators.
|
||||
// The handler just needs to forward it without reassembly.
|
||||
if len(chunk) > 0 {
|
||||
_, _ = writer.Write(chunk)
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
flusher.Flush()
|
||||
}
|
||||
chunkIdx++
|
||||
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
@@ -276,21 +247,20 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
// An error occurred: emit as a proper SSE error event
|
||||
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
||||
_, _ = writer.WriteString("event: error\n")
|
||||
_, _ = writer.WriteString("data: ")
|
||||
_, _ = writer.Write(errorBytes)
|
||||
_, _ = writer.WriteString("\n\n")
|
||||
_ = writer.Flush()
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,19 +136,29 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
newCtx = context.WithValue(newCtx, "gin", c)
|
||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||
return newCtx, func(params ...interface{}) {
|
||||
if h.Cfg.RequestLog {
|
||||
if len(params) == 1 {
|
||||
data := params[0]
|
||||
switch data.(type) {
|
||||
case []byte:
|
||||
appendAPIResponse(c, data.([]byte))
|
||||
case error:
|
||||
appendAPIResponse(c, []byte(data.(error).Error()))
|
||||
case string:
|
||||
appendAPIResponse(c, []byte(data.(string)))
|
||||
case bool:
|
||||
case nil:
|
||||
if h.Cfg.RequestLog && len(params) == 1 {
|
||||
var payload []byte
|
||||
switch data := params[0].(type) {
|
||||
case []byte:
|
||||
payload = data
|
||||
case error:
|
||||
if data != nil {
|
||||
payload = []byte(data.Error())
|
||||
}
|
||||
case string:
|
||||
payload = []byte(data)
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) > 0 && bytes.Contains(existingBytes, trimmedPayload) {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
appendAPIResponse(c, payload)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return nil, fmt.Errorf("iflow authentication failed: missing account identifier")
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("iflow-%s.json", email)
|
||||
fileName := fmt.Sprintf("iflow-%s-%d.json", email, time.Now().Unix())
|
||||
metadata := map[string]any{
|
||||
"email": email,
|
||||
"api_key": tokenStorage.APIKey,
|
||||
|
||||
@@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 30 * time.Minute
|
||||
d := 5 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
@@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
"source": "aws-builder-id",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con
|
||||
"source": "google-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con
|
||||
"source": "github-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
@@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type RefreshEvaluator interface {
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
refreshFailureBackoff = 1 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
quotaBackoffMax = 30 * time.Minute
|
||||
)
|
||||
@@ -375,10 +375,19 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
if accountType == "api_key" {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
@@ -423,10 +432,19 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
if accountType == "api_key" {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
@@ -471,10 +489,19 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
if accountType == "api_key" {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
@@ -1471,7 +1498,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
updated.Runtime = auth.Runtime
|
||||
}
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
// Preserve NextRefreshAfter set by the Authenticator
|
||||
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
|
||||
@@ -157,6 +157,20 @@ func (m *ModelState) Clone() *ModelState {
|
||||
return &copyState
|
||||
}
|
||||
|
||||
func (a *Auth) ProxyInfo() string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
proxyStr := strings.TrimSpace(a.ProxyURL)
|
||||
if proxyStr == "" {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.Index(proxyStr, "://"); idx > 0 {
|
||||
return "via " + proxyStr[:idx] + " proxy"
|
||||
}
|
||||
return "via proxy"
|
||||
}
|
||||
|
||||
func (a *Auth) AccountInfo() (string, string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
760
test/thinking_conversion_test.go
Normal file
760
test/thinking_conversion_test.go
Normal file
@@ -0,0 +1,760 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// isOpenAICompatModel reports whether model is the test's stand-in for an
// OpenAI-compatible channel model, i.e. one whose reasoning effort should be
// passed through even though the registry knows nothing about its thinking
// support. It mimics the allowCompat behavior of OpenAICompatExecutor.
func isOpenAICompatModel(model string) bool {
	switch model {
	case "openai-compat":
		return true
	default:
		return false
	}
}
|
||||
|
||||
// registerCoreModels loads representative models across providers into the registry
|
||||
// so NormalizeThinkingBudget and level validation use real ranges.
|
||||
func registerCoreModels(t *testing.T) func() {
|
||||
t.Helper()
|
||||
reg := registry.GetGlobalRegistry()
|
||||
uid := fmt.Sprintf("thinking-core-%d", time.Now().UnixNano())
|
||||
reg.RegisterClient(uid+"-gemini", "gemini", registry.GetGeminiModels())
|
||||
reg.RegisterClient(uid+"-claude", "claude", registry.GetClaudeModels())
|
||||
reg.RegisterClient(uid+"-openai", "codex", registry.GetOpenAIModels())
|
||||
reg.RegisterClient(uid+"-qwen", "qwen", registry.GetQwenModels())
|
||||
// Custom openai-compatible model with forced thinking suffix passthrough.
|
||||
// No Thinking field - simulates an external model added via openai-compat
|
||||
// where the registry has no knowledge of its thinking capabilities.
|
||||
// The allowCompat flag should preserve reasoning effort for such models.
|
||||
customOpenAIModels := []*registry.ModelInfo{
|
||||
{
|
||||
ID: "openai-compat",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "custom-provider",
|
||||
Type: "openai",
|
||||
DisplayName: "OpenAI Compatible Model",
|
||||
Description: "OpenAI-compatible model with forced thinking suffix support",
|
||||
},
|
||||
}
|
||||
reg.RegisterClient(uid+"-custom-openai", "codex", customOpenAIModels)
|
||||
return func() {
|
||||
reg.UnregisterClient(uid + "-gemini")
|
||||
reg.UnregisterClient(uid + "-claude")
|
||||
reg.UnregisterClient(uid + "-openai")
|
||||
reg.UnregisterClient(uid + "-qwen")
|
||||
reg.UnregisterClient(uid + "-custom-openai")
|
||||
}
|
||||
}
|
||||
|
||||
// thinkingTestModels lists one model per thinking behavior under test.
var thinkingTestModels = []string{
	"gpt-5",           // level-based thinking model
	"gemini-2.5-pro",  // numeric-budget thinking model
	"qwen3-code-plus", // no thinking support
	"openai-compat",   // allowCompat=true (OpenAI-compatible channel)
}

// Source and target protocols that the conversion matrix iterates over.
var (
	thinkingTestFromProtocols = []string{"openai", "claude", "gemini", "openai-response"}
	thinkingTestToProtocols   = []string{"gemini", "claude", "openai", "codex"}
)

// thinkingNumericSamples are the numeric budgets exercised by the tests.
// Budget to level equivalents:
//
//	-1          -> auto
//	0           -> none
//	1..1024     -> low
//	1025..8192  -> medium
//	8193..24576 -> high
//	>24576      -> model highest level (right-most in Levels)
var thinkingNumericSamples = []int{-1, 0, 1023, 1025, 8193, 64000}

// thinkingLevelSamples are the effort levels exercised by the tests.
// Level to budget equivalents:
//
//	auto    -> -1
//	none    -> 0
//	minimal -> 512
//	low     -> 1024
//	medium  -> 8192
//	high    -> 24576
//	xhigh   -> 32768
//	invalid -> invalid (no mapping)
var thinkingLevelSamples = []string{"auto", "none", "minimal", "low", "medium", "high", "xhigh", "invalid"}
|
||||
|
||||
// buildRawPayload returns a minimal single-message request body ("hi" from the
// user) in the wire shape of fromProtocol, with modelWithSuffix inserted
// verbatim as the model name. "gemini" uses contents/parts, "openai-response"
// uses input/content, and every other protocol uses the chat-style messages
// array.
func buildRawPayload(fromProtocol, modelWithSuffix string) []byte {
	var tmpl string
	switch fromProtocol {
	case "gemini":
		tmpl = `{"model":"%s","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`
	case "openai-response":
		tmpl = `{"model":"%s","input":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
	default: // openai / claude and other chat-style payloads
		tmpl = `{"model":"%s","messages":[{"role":"user","content":"hi"}]}`
	}
	return []byte(fmt.Sprintf(tmpl, modelWithSuffix))
}
|
||||
|
||||
// normalizeCodexPayload mirrors codex_executor's reasoning + streaming tweaks.
|
||||
func normalizeCodexPayload(body []byte, upstreamModel string, allowCompat bool) ([]byte, error) {
|
||||
body = executor.NormalizeThinkingConfig(body, upstreamModel, allowCompat)
|
||||
if err := executor.ValidateThinkingConfig(body, upstreamModel); err != nil {
|
||||
return body, err
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// buildBodyForProtocol runs a minimal request through the same translation and
|
||||
// thinking pipeline used in executors for the given target protocol.
|
||||
func buildBodyForProtocol(t *testing.T, fromProtocol, toProtocol, modelWithSuffix string) ([]byte, error) {
|
||||
t.Helper()
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(modelWithSuffix)
|
||||
upstreamModel := util.ResolveOriginalModel(normalizedModel, metadata)
|
||||
raw := buildRawPayload(fromProtocol, modelWithSuffix)
|
||||
stream := fromProtocol != toProtocol
|
||||
|
||||
body := sdktranslator.TranslateRequest(
|
||||
sdktranslator.FromString(fromProtocol),
|
||||
sdktranslator.FromString(toProtocol),
|
||||
normalizedModel,
|
||||
raw,
|
||||
stream,
|
||||
)
|
||||
|
||||
var err error
|
||||
allowCompat := isOpenAICompatModel(normalizedModel)
|
||||
switch toProtocol {
|
||||
case "gemini":
|
||||
body = executor.ApplyThinkingMetadata(body, metadata, normalizedModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(normalizedModel, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(normalizedModel, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(normalizedModel, body)
|
||||
case "claude":
|
||||
if budget, ok := util.ResolveClaudeThinkingConfig(normalizedModel, metadata); ok {
|
||||
body = util.ApplyClaudeThinkingConfig(body, budget)
|
||||
}
|
||||
case "openai":
|
||||
body = executor.ApplyReasoningEffortMetadata(body, metadata, normalizedModel, "reasoning_effort", allowCompat)
|
||||
body = executor.NormalizeThinkingConfig(body, upstreamModel, allowCompat)
|
||||
err = executor.ValidateThinkingConfig(body, upstreamModel)
|
||||
case "codex": // OpenAI responses / codex
|
||||
// Codex does not support allowCompat; always use false.
|
||||
body = executor.ApplyReasoningEffortMetadata(body, metadata, normalizedModel, "reasoning.effort", false)
|
||||
// Mirror CodexExecutor final normalization and model override so tests log the final body.
|
||||
body, err = normalizeCodexPayload(body, upstreamModel, false)
|
||||
default:
|
||||
}
|
||||
|
||||
// Mirror executor behavior: final payload uses the upstream (base) model name.
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
|
||||
// For tests we only keep model + thinking-related fields to avoid noise.
|
||||
body = filterThinkingBody(toProtocol, body, upstreamModel, normalizedModel)
|
||||
return body, err
|
||||
}
|
||||
|
||||
// filterThinkingBody projects the translated payload down to only model and
|
||||
// thinking-related fields for the given target protocol.
|
||||
func filterThinkingBody(toProtocol string, body []byte, upstreamModel, normalizedModel string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
out := []byte(`{}`)
|
||||
|
||||
// Preserve model if present, otherwise fall back to upstream/normalized model.
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() {
|
||||
out, _ = sjson.SetBytes(out, "model", m.Value())
|
||||
} else if upstreamModel != "" {
|
||||
out, _ = sjson.SetBytes(out, "model", upstreamModel)
|
||||
} else if normalizedModel != "" {
|
||||
out, _ = sjson.SetBytes(out, "model", normalizedModel)
|
||||
}
|
||||
|
||||
switch toProtocol {
|
||||
case "gemini":
|
||||
if tc := gjson.GetBytes(body, "generationConfig.thinkingConfig"); tc.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "generationConfig.thinkingConfig", []byte(tc.Raw))
|
||||
}
|
||||
case "claude":
|
||||
if tcfg := gjson.GetBytes(body, "thinking"); tcfg.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "thinking", []byte(tcfg.Raw))
|
||||
}
|
||||
case "openai":
|
||||
if re := gjson.GetBytes(body, "reasoning_effort"); re.Exists() {
|
||||
out, _ = sjson.SetBytes(out, "reasoning_effort", re.Value())
|
||||
}
|
||||
case "codex":
|
||||
if re := gjson.GetBytes(body, "reasoning.effort"); re.Exists() {
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", re.Value())
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
modelSuffix string
|
||||
}
|
||||
|
||||
numericName := func(budget int) string {
|
||||
if budget < 0 {
|
||||
return "numeric-neg1"
|
||||
}
|
||||
return fmt.Sprintf("numeric-%d", budget)
|
||||
}
|
||||
|
||||
for _, model := range thinkingTestModels {
|
||||
_ = registry.GetGlobalRegistry().GetModelInfo(model)
|
||||
|
||||
for _, from := range thinkingTestFromProtocols {
|
||||
// Scenario selection follows protocol semantics:
|
||||
// - OpenAI-style protocols (openai/openai-response) express thinking as levels.
|
||||
// - Claude/Gemini-style protocols express thinking as numeric budgets.
|
||||
cases := []scenario{
|
||||
{name: "no-suffix", modelSuffix: model},
|
||||
}
|
||||
if from == "openai" || from == "openai-response" {
|
||||
for _, lvl := range thinkingLevelSamples {
|
||||
cases = append(cases, scenario{
|
||||
name: "level-" + lvl,
|
||||
modelSuffix: fmt.Sprintf("%s(%s)", model, lvl),
|
||||
})
|
||||
}
|
||||
} else { // claude or gemini
|
||||
for _, budget := range thinkingNumericSamples {
|
||||
budget := budget
|
||||
cases = append(cases, scenario{
|
||||
name: numericName(budget),
|
||||
modelSuffix: fmt.Sprintf("%s(%d)", model, budget),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, to := range thinkingTestToProtocols {
|
||||
if from == to {
|
||||
continue
|
||||
}
|
||||
t.Logf("─────────────────────────────────────────────────────────────────────────────────")
|
||||
t.Logf(" %s -> %s | model: %s", from, to, model)
|
||||
t.Logf("─────────────────────────────────────────────────────────────────────────────────")
|
||||
for _, cs := range cases {
|
||||
from := from
|
||||
to := to
|
||||
cs := cs
|
||||
testName := fmt.Sprintf("%s->%s/%s/%s", from, to, model, cs.name)
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(cs.modelSuffix)
|
||||
expectPresent, expectValue, expectErr := func() (bool, string, bool) {
|
||||
switch to {
|
||||
case "gemini":
|
||||
budget, include, ok := util.ResolveThinkingConfigFromMetadata(normalizedModel, metadata)
|
||||
if !ok || !util.ModelSupportsThinking(normalizedModel) {
|
||||
return false, "", false
|
||||
}
|
||||
if include != nil && !*include {
|
||||
return false, "", false
|
||||
}
|
||||
if budget == nil {
|
||||
return false, "", false
|
||||
}
|
||||
norm := util.NormalizeThinkingBudget(normalizedModel, *budget)
|
||||
return true, fmt.Sprintf("%d", norm), false
|
||||
case "claude":
|
||||
if !util.ModelSupportsThinking(normalizedModel) {
|
||||
return false, "", false
|
||||
}
|
||||
budget, ok := util.ResolveClaudeThinkingConfig(normalizedModel, metadata)
|
||||
if !ok || budget == nil {
|
||||
return false, "", false
|
||||
}
|
||||
return true, fmt.Sprintf("%d", *budget), false
|
||||
case "openai":
|
||||
allowCompat := isOpenAICompatModel(normalizedModel)
|
||||
if !util.ModelSupportsThinking(normalizedModel) && !allowCompat {
|
||||
return false, "", false
|
||||
}
|
||||
// For allowCompat models, pass through effort directly without validation
|
||||
if allowCompat {
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if ok && strings.TrimSpace(effort) != "" {
|
||||
return true, strings.ToLower(strings.TrimSpace(effort)), false
|
||||
}
|
||||
// Check numeric budget fallback for allowCompat
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
if !util.ModelUsesThinkingLevels(normalizedModel) {
|
||||
// Non-levels models don't support effort strings in openai
|
||||
return false, "", false
|
||||
}
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || strings.TrimSpace(effort) == "" {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap {
|
||||
effort = mapped
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok || strings.TrimSpace(effort) == "" {
|
||||
return false, "", false
|
||||
}
|
||||
effort = strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, effort); okLevel {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true // validation would fail
|
||||
case "codex":
|
||||
// Codex does not support allowCompat; require thinking-capable level models.
|
||||
if !util.ModelSupportsThinking(normalizedModel) || !util.ModelUsesThinkingLevels(normalizedModel) {
|
||||
return false, "", false
|
||||
}
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if ok && strings.TrimSpace(effort) != "" {
|
||||
effort = strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, effort); okLevel {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true
|
||||
}
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
mapped = strings.ToLower(strings.TrimSpace(mapped))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, mapped); okLevel {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true
|
||||
}
|
||||
}
|
||||
if from != "openai-response" {
|
||||
// Codex translators default reasoning.effort to "medium" when
|
||||
// no explicit thinking suffix/metadata is provided.
|
||||
return true, "medium", false
|
||||
}
|
||||
return false, "", false
|
||||
default:
|
||||
return false, "", false
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := buildBodyForProtocol(t, from, to, cs.modelSuffix)
|
||||
actualPresent, actualValue := func() (bool, string) {
|
||||
path := ""
|
||||
switch to {
|
||||
case "gemini":
|
||||
path = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
case "claude":
|
||||
path = "thinking.budget_tokens"
|
||||
case "openai":
|
||||
path = "reasoning_effort"
|
||||
case "codex":
|
||||
path = "reasoning.effort"
|
||||
}
|
||||
if path == "" {
|
||||
return false, ""
|
||||
}
|
||||
val := gjson.GetBytes(body, path)
|
||||
if to == "codex" && !val.Exists() {
|
||||
reasoning := gjson.GetBytes(body, "reasoning")
|
||||
if reasoning.Exists() {
|
||||
val = reasoning.Get("effort")
|
||||
}
|
||||
}
|
||||
if !val.Exists() {
|
||||
return false, ""
|
||||
}
|
||||
if val.Type == gjson.Number {
|
||||
return true, fmt.Sprintf("%d", val.Int())
|
||||
}
|
||||
return true, val.String()
|
||||
}()
|
||||
|
||||
t.Logf("from=%s to=%s model=%s suffix=%s present(expect=%v got=%v) value(expect=%s got=%s) err(expect=%v got=%v) body=%s",
|
||||
from, to, model, cs.modelSuffix, expectPresent, actualPresent, expectValue, actualValue, expectErr, err != nil, string(body))
|
||||
|
||||
if expectErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error but got none, body=%s", string(body))
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v body=%s", err, string(body))
|
||||
}
|
||||
|
||||
if expectPresent != actualPresent {
|
||||
t.Fatalf("presence mismatch: expect %v got %v body=%s", expectPresent, actualPresent, string(body))
|
||||
}
|
||||
if expectPresent && expectValue != actualValue {
|
||||
t.Fatalf("value mismatch: expect %s got %s body=%s", expectValue, actualValue, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildRawPayloadWithThinking creates a payload with thinking parameters already in the body.
|
||||
// This tests the path where thinking comes from the raw payload, not model suffix.
|
||||
func buildRawPayloadWithThinking(fromProtocol, model string, thinkingParam any) []byte {
|
||||
switch fromProtocol {
|
||||
case "gemini":
|
||||
base := fmt.Sprintf(`{"model":"%s","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`, model)
|
||||
if budget, ok := thinkingParam.(int); ok {
|
||||
base, _ = sjson.Set(base, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
}
|
||||
return []byte(base)
|
||||
case "openai-response":
|
||||
base := fmt.Sprintf(`{"model":"%s","input":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`, model)
|
||||
if effort, ok := thinkingParam.(string); ok && effort != "" {
|
||||
base, _ = sjson.Set(base, "reasoning.effort", effort)
|
||||
}
|
||||
return []byte(base)
|
||||
case "openai":
|
||||
base := fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, model)
|
||||
if effort, ok := thinkingParam.(string); ok && effort != "" {
|
||||
base, _ = sjson.Set(base, "reasoning_effort", effort)
|
||||
}
|
||||
return []byte(base)
|
||||
case "claude":
|
||||
base := fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, model)
|
||||
if budget, ok := thinkingParam.(int); ok {
|
||||
base, _ = sjson.Set(base, "thinking.type", "enabled")
|
||||
base, _ = sjson.Set(base, "thinking.budget_tokens", budget)
|
||||
}
|
||||
return []byte(base)
|
||||
default:
|
||||
return []byte(fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, model))
|
||||
}
|
||||
}
|
||||
|
||||
// buildBodyForProtocolWithRawThinking translates payload with raw thinking params.
|
||||
func buildBodyForProtocolWithRawThinking(t *testing.T, fromProtocol, toProtocol, model string, thinkingParam any) ([]byte, error) {
|
||||
t.Helper()
|
||||
raw := buildRawPayloadWithThinking(fromProtocol, model, thinkingParam)
|
||||
stream := fromProtocol != toProtocol
|
||||
|
||||
body := sdktranslator.TranslateRequest(
|
||||
sdktranslator.FromString(fromProtocol),
|
||||
sdktranslator.FromString(toProtocol),
|
||||
model,
|
||||
raw,
|
||||
stream,
|
||||
)
|
||||
|
||||
var err error
|
||||
allowCompat := isOpenAICompatModel(model)
|
||||
switch toProtocol {
|
||||
case "gemini":
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
case "claude":
|
||||
// For raw payload, Claude thinking is passed through by translator
|
||||
// No additional processing needed as thinking is already in body
|
||||
case "openai":
|
||||
body = executor.NormalizeThinkingConfig(body, model, allowCompat)
|
||||
err = executor.ValidateThinkingConfig(body, model)
|
||||
case "codex":
|
||||
// Codex does not support allowCompat; always use false.
|
||||
body, err = normalizeCodexPayload(body, model, false)
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body = filterThinkingBody(toProtocol, body, model, model)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
thinkingParam any // int for budget, string for effort level
|
||||
}
|
||||
|
||||
numericName := func(budget int) string {
|
||||
if budget < 0 {
|
||||
return "budget-neg1"
|
||||
}
|
||||
return fmt.Sprintf("budget-%d", budget)
|
||||
}
|
||||
|
||||
for _, model := range thinkingTestModels {
|
||||
supportsThinking := util.ModelSupportsThinking(model)
|
||||
usesLevels := util.ModelUsesThinkingLevels(model)
|
||||
allowCompat := isOpenAICompatModel(model)
|
||||
|
||||
for _, from := range thinkingTestFromProtocols {
|
||||
var cases []scenario
|
||||
switch from {
|
||||
case "openai", "openai-response":
|
||||
cases = []scenario{
|
||||
{name: "no-thinking", thinkingParam: nil},
|
||||
}
|
||||
for _, lvl := range thinkingLevelSamples {
|
||||
cases = append(cases, scenario{
|
||||
name: "effort-" + lvl,
|
||||
thinkingParam: lvl,
|
||||
})
|
||||
}
|
||||
case "gemini", "claude":
|
||||
cases = []scenario{
|
||||
{name: "no-thinking", thinkingParam: nil},
|
||||
}
|
||||
for _, budget := range thinkingNumericSamples {
|
||||
budget := budget
|
||||
cases = append(cases, scenario{
|
||||
name: numericName(budget),
|
||||
thinkingParam: budget,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, to := range thinkingTestToProtocols {
|
||||
if from == to {
|
||||
continue
|
||||
}
|
||||
t.Logf("═══════════════════════════════════════════════════════════════════════════════")
|
||||
t.Logf(" RAW PAYLOAD: %s -> %s | model: %s", from, to, model)
|
||||
t.Logf("═══════════════════════════════════════════════════════════════════════════════")
|
||||
|
||||
for _, cs := range cases {
|
||||
from := from
|
||||
to := to
|
||||
cs := cs
|
||||
testName := fmt.Sprintf("raw/%s->%s/%s/%s", from, to, model, cs.name)
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
expectPresent, expectValue, expectErr := func() (bool, string, bool) {
|
||||
if cs.thinkingParam == nil {
|
||||
if to == "codex" && from != "openai-response" && supportsThinking && usesLevels {
|
||||
// Codex translators default reasoning.effort to "medium" for thinking-capable level models
|
||||
return true, "medium", false
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
|
||||
switch to {
|
||||
case "gemini":
|
||||
if !supportsThinking || usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
// Gemini expects numeric budget (only for non-level models)
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
norm := util.NormalizeThinkingBudget(model, budget)
|
||||
return true, fmt.Sprintf("%d", norm), false
|
||||
}
|
||||
// Convert effort level to budget for non-level models only
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
// "none" disables thinking - no thinkingBudget in output
|
||||
if strings.ToLower(effort) == "none" {
|
||||
return false, "", false
|
||||
}
|
||||
if budget, okB := util.ThinkingEffortToBudget(model, effort); okB {
|
||||
// ThinkingEffortToBudget already returns normalized budget
|
||||
return true, fmt.Sprintf("%d", budget), false
|
||||
}
|
||||
// Invalid effort does not map to a budget
|
||||
return false, "", false
|
||||
}
|
||||
return false, "", false
|
||||
case "claude":
|
||||
if !supportsThinking || usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
// Claude expects numeric budget (only for non-level models)
|
||||
if budget, ok := cs.thinkingParam.(int); ok && budget > 0 {
|
||||
norm := util.NormalizeThinkingBudget(model, budget)
|
||||
return true, fmt.Sprintf("%d", norm), false
|
||||
}
|
||||
// Convert effort level to budget for non-level models only
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
// "none" and "auto" don't produce budget_tokens
|
||||
lower := strings.ToLower(effort)
|
||||
if lower == "none" || lower == "auto" {
|
||||
return false, "", false
|
||||
}
|
||||
if budget, okB := util.ThinkingEffortToBudget(model, effort); okB {
|
||||
// ThinkingEffortToBudget already returns normalized budget
|
||||
return true, fmt.Sprintf("%d", budget), false
|
||||
}
|
||||
// Invalid effort - claude sets thinking.type:enabled but no budget_tokens
|
||||
return false, "", false
|
||||
}
|
||||
return false, "", false
|
||||
case "openai":
|
||||
if allowCompat {
|
||||
if effort, ok := cs.thinkingParam.(string); ok && strings.TrimSpace(effort) != "" {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
return true, normalized, false
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
if !supportsThinking || !usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
if normalized, okN := util.NormalizeReasoningEffortLevel(model, effort); okN {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true // invalid level
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
}
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
return false, "", false
|
||||
case "codex":
|
||||
// Codex does not support allowCompat; require thinking-capable level models.
|
||||
if !supportsThinking || !usesLevels {
|
||||
return false, "", false
|
||||
}
|
||||
if effort, ok := cs.thinkingParam.(string); ok && effort != "" {
|
||||
if normalized, okN := util.NormalizeReasoningEffortLevel(model, effort); okN {
|
||||
return true, normalized, false
|
||||
}
|
||||
return false, "", true
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
}
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
if from != "openai-response" {
|
||||
// Codex translators default reasoning.effort to "medium" for thinking-capable models
|
||||
return true, "medium", false
|
||||
}
|
||||
return false, "", false
|
||||
}
|
||||
return false, "", false
|
||||
}()
|
||||
|
||||
body, err := buildBodyForProtocolWithRawThinking(t, from, to, model, cs.thinkingParam)
|
||||
actualPresent, actualValue := func() (bool, string) {
|
||||
path := ""
|
||||
switch to {
|
||||
case "gemini":
|
||||
path = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
case "claude":
|
||||
path = "thinking.budget_tokens"
|
||||
case "openai":
|
||||
path = "reasoning_effort"
|
||||
case "codex":
|
||||
path = "reasoning.effort"
|
||||
}
|
||||
if path == "" {
|
||||
return false, ""
|
||||
}
|
||||
val := gjson.GetBytes(body, path)
|
||||
if to == "codex" && !val.Exists() {
|
||||
reasoning := gjson.GetBytes(body, "reasoning")
|
||||
if reasoning.Exists() {
|
||||
val = reasoning.Get("effort")
|
||||
}
|
||||
}
|
||||
if !val.Exists() {
|
||||
return false, ""
|
||||
}
|
||||
if val.Type == gjson.Number {
|
||||
return true, fmt.Sprintf("%d", val.Int())
|
||||
}
|
||||
return true, val.String()
|
||||
}()
|
||||
|
||||
t.Logf("from=%s to=%s model=%s param=%v present(expect=%v got=%v) value(expect=%s got=%s) err(expect=%v got=%v) body=%s",
|
||||
from, to, model, cs.thinkingParam, expectPresent, actualPresent, expectValue, actualValue, expectErr, err != nil, string(body))
|
||||
|
||||
if expectErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error but got none, body=%s", string(body))
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v body=%s", err, string(body))
|
||||
}
|
||||
|
||||
if expectPresent != actualPresent {
|
||||
t.Fatalf("presence mismatch: expect %v got %v body=%s", expectPresent, actualPresent, string(body))
|
||||
}
|
||||
if expectPresent && expectValue != actualValue {
|
||||
t.Fatalf("value mismatch: expect %s got %s body=%s", expectValue, actualValue, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
budget int
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{name: "dynamic-auto", model: "gpt-5", budget: -1, want: "auto", ok: true},
|
||||
{name: "zero-none", model: "gpt-5", budget: 0, want: "none", ok: true},
|
||||
{name: "low-min", model: "gpt-5", budget: 1, want: "low", ok: true},
|
||||
{name: "low-max", model: "gpt-5", budget: 1024, want: "low", ok: true},
|
||||
{name: "medium-min", model: "gpt-5", budget: 1025, want: "medium", ok: true},
|
||||
{name: "medium-max", model: "gpt-5", budget: 8192, want: "medium", ok: true},
|
||||
{name: "high-min", model: "gpt-5", budget: 8193, want: "high", ok: true},
|
||||
{name: "high-max", model: "gpt-5", budget: 24576, want: "high", ok: true},
|
||||
{name: "over-max-clamps-to-highest", model: "gpt-5", budget: 64000, want: "high", ok: true},
|
||||
{name: "over-max-xhigh-model", model: "gpt-5.2", budget: 50000, want: "xhigh", ok: true},
|
||||
{name: "negative-unsupported", model: "gpt-5", budget: -5, want: "", ok: false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.OpenAIThinkingBudgetToEffort(cs.model, cs.budget)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s budget=%d: expect %v got %v", cs.model, cs.budget, cs.ok, ok)
|
||||
}
|
||||
if got != cs.want {
|
||||
t.Fatalf("value mismatch for model=%s budget=%d: expect %q got %q", cs.model, cs.budget, cs.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user