mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-18 15:02:28 +00:00
Compare commits
10 Commits
v6.6.103-0
...
v6.6.107-0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
496f6770a5 | ||
|
|
5a7e5bd870 | ||
|
|
6f8a8f8136 | ||
|
|
5a2bf191fc | ||
|
|
a235fb1507 | ||
|
|
0d66522ed8 | ||
|
|
b163f8ed9e | ||
|
|
83e5f60b8b | ||
|
|
5b433f962f | ||
|
|
a1da6ff5ac |
@@ -74,6 +74,7 @@ func main() {
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
@@ -96,6 +97,7 @@ func main() {
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
@@ -454,7 +456,8 @@ func main() {
|
||||
|
||||
// Create login options to be used in authentication flows.
|
||||
options := &cmd.LoginOptions{
|
||||
NoBrowser: noBrowser,
|
||||
NoBrowser: noBrowser,
|
||||
CallbackPort: oauthCallbackPort,
|
||||
}
|
||||
|
||||
// Register the shared token store once so all components use the same persistence backend.
|
||||
|
||||
@@ -29,8 +29,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiDefaultCallbackPort = 8085
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,8 +50,9 @@ type GeminiAuth struct {
|
||||
|
||||
// WebLoginOptions customizes the interactive OAuth flow.
|
||||
type WebLoginOptions struct {
|
||||
NoBrowser bool
|
||||
Prompt func(string) (string, error)
|
||||
NoBrowser bool
|
||||
CallbackPort int
|
||||
Prompt func(string) (string, error)
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
@@ -106,7 +114,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOauthClientID,
|
||||
ClientSecret: geminiOauthClientSecret,
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||
RedirectURL: callbackURL, // This will be used by the local server.
|
||||
Scopes: geminiOauthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
|
||||
config.RedirectURL = callbackURL
|
||||
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
// Check if browser is available
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available on this system")
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
} else {
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
|
||||
// Log platform info for debugging
|
||||
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
|
||||
@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||
|
||||
@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
|
||||
@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
loginOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: trimmedProjectID,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: callbackPrompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: trimmedProjectID,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: callbackPrompt,
|
||||
}
|
||||
|
||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Prompt: callbackPrompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Prompt: callbackPrompt,
|
||||
})
|
||||
if errClient != nil {
|
||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||
|
||||
@@ -19,6 +19,9 @@ type LoginOptions struct {
|
||||
// NoBrowser indicates whether to skip opening the browser automatically.
|
||||
NoBrowser bool
|
||||
|
||||
// CallbackPort overrides the local OAuth callback port when set (>0).
|
||||
CallbackPort int
|
||||
|
||||
// Prompt allows the caller to provide interactive input when needed.
|
||||
Prompt func(prompt string) (string, error)
|
||||
}
|
||||
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
|
||||
@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
|
||||
@@ -255,6 +255,10 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -293,6 +297,10 @@ type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -331,6 +339,10 @@ type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -397,6 +409,10 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Priority controls selection preference when multiple providers or credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
|
||||
@@ -13,6 +13,10 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
|
||||
@@ -251,6 +251,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
st.InFuncBlock = false
|
||||
} else if st.ReasoningActive {
|
||||
|
||||
@@ -520,7 +520,7 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Truncate long descriptions
|
||||
// Truncate long descriptions (individual tool limit)
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
@@ -538,6 +538,10 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
})
|
||||
}
|
||||
|
||||
// Apply dynamic compression if total tools size exceeds threshold
|
||||
// This prevents 500 errors when Claude Code sends too many tools
|
||||
kiroTools = compressToolsIfNeeded(kiroTools)
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
|
||||
191
internal/translator/kiro/claude/tool_compression.go
Normal file
191
internal/translator/kiro/claude/tool_compression.go
Normal file
@@ -0,0 +1,191 @@
|
||||
// Package claude provides tool compression functionality for Kiro translator.
|
||||
// This file implements dynamic tool compression to reduce tool payload size
|
||||
// when it exceeds the target threshold, preventing 500 errors from Kiro API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// calculateToolsSize calculates the JSON serialized size of the tools list.
|
||||
// Returns the size in bytes.
|
||||
func calculateToolsSize(tools []KiroToolWrapper) int {
|
||||
if len(tools) == 0 {
|
||||
return 0
|
||||
}
|
||||
data, err := json.Marshal(tools)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to marshal tools for size calculation: %v", err)
|
||||
return 0
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
|
||||
// simplifyInputSchema simplifies the input_schema by keeping only essential fields:
|
||||
// type, enum, required. Recursively processes nested properties.
|
||||
func simplifyInputSchema(schema interface{}) interface{} {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
schemaMap, ok := schema.(map[string]interface{})
|
||||
if !ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
simplified := make(map[string]interface{})
|
||||
|
||||
// Keep essential fields
|
||||
if t, ok := schemaMap["type"]; ok {
|
||||
simplified["type"] = t
|
||||
}
|
||||
if enum, ok := schemaMap["enum"]; ok {
|
||||
simplified["enum"] = enum
|
||||
}
|
||||
if required, ok := schemaMap["required"]; ok {
|
||||
simplified["required"] = required
|
||||
}
|
||||
|
||||
// Recursively process properties
|
||||
if properties, ok := schemaMap["properties"].(map[string]interface{}); ok {
|
||||
simplifiedProps := make(map[string]interface{})
|
||||
for key, value := range properties {
|
||||
simplifiedProps[key] = simplifyInputSchema(value)
|
||||
}
|
||||
simplified["properties"] = simplifiedProps
|
||||
}
|
||||
|
||||
// Process items for array types
|
||||
if items, ok := schemaMap["items"]; ok {
|
||||
simplified["items"] = simplifyInputSchema(items)
|
||||
}
|
||||
|
||||
// Process additionalProperties if present
|
||||
if additionalProps, ok := schemaMap["additionalProperties"]; ok {
|
||||
simplified["additionalProperties"] = simplifyInputSchema(additionalProps)
|
||||
}
|
||||
|
||||
// Process anyOf, oneOf, allOf
|
||||
for _, key := range []string{"anyOf", "oneOf", "allOf"} {
|
||||
if arr, ok := schemaMap[key].([]interface{}); ok {
|
||||
simplifiedArr := make([]interface{}, len(arr))
|
||||
for i, item := range arr {
|
||||
simplifiedArr[i] = simplifyInputSchema(item)
|
||||
}
|
||||
simplified[key] = simplifiedArr
|
||||
}
|
||||
}
|
||||
|
||||
return simplified
|
||||
}
|
||||
|
||||
// compressToolDescription compresses a description to the target length.
|
||||
// Ensures the result is at least MinToolDescriptionLength characters.
|
||||
// Uses UTF-8 safe truncation.
|
||||
func compressToolDescription(description string, targetLength int) string {
|
||||
if targetLength < kirocommon.MinToolDescriptionLength {
|
||||
targetLength = kirocommon.MinToolDescriptionLength
|
||||
}
|
||||
|
||||
if len(description) <= targetLength {
|
||||
return description
|
||||
}
|
||||
|
||||
// Find a safe truncation point (UTF-8 boundary)
|
||||
truncLen := targetLength - 3 // Leave room for "..."
|
||||
|
||||
// Ensure we don't cut in the middle of a UTF-8 character
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
truncLen--
|
||||
}
|
||||
|
||||
if truncLen <= 0 {
|
||||
return description[:kirocommon.MinToolDescriptionLength]
|
||||
}
|
||||
|
||||
return description[:truncLen] + "..."
|
||||
}
|
||||
|
||||
// compressToolsIfNeeded compresses tools if their total size exceeds the target threshold.
|
||||
// Compression strategy:
|
||||
// 1. First, check if compression is needed (size > ToolCompressionTargetSize)
|
||||
// 2. Step 1: Simplify input_schema (keep only type/enum/required)
|
||||
// 3. Step 2: Proportionally compress descriptions (minimum MinToolDescriptionLength chars)
|
||||
// Returns the compressed tools list.
|
||||
func compressToolsIfNeeded(tools []KiroToolWrapper) []KiroToolWrapper {
|
||||
if len(tools) == 0 {
|
||||
return tools
|
||||
}
|
||||
|
||||
originalSize := calculateToolsSize(tools)
|
||||
if originalSize <= kirocommon.ToolCompressionTargetSize {
|
||||
log.Debugf("kiro: tools size %d bytes is within target %d bytes, no compression needed",
|
||||
originalSize, kirocommon.ToolCompressionTargetSize)
|
||||
return tools
|
||||
}
|
||||
|
||||
log.Infof("kiro: tools size %d bytes exceeds target %d bytes, starting compression",
|
||||
originalSize, kirocommon.ToolCompressionTargetSize)
|
||||
|
||||
// Create a copy of tools to avoid modifying the original
|
||||
compressedTools := make([]KiroToolWrapper, len(tools))
|
||||
for i, tool := range tools {
|
||||
compressedTools[i] = KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: tool.ToolSpecification.Name,
|
||||
Description: tool.ToolSpecification.Description,
|
||||
InputSchema: KiroInputSchema{JSON: tool.ToolSpecification.InputSchema.JSON},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Simplify input_schema
|
||||
for i := range compressedTools {
|
||||
compressedTools[i].ToolSpecification.InputSchema.JSON =
|
||||
simplifyInputSchema(compressedTools[i].ToolSpecification.InputSchema.JSON)
|
||||
}
|
||||
|
||||
sizeAfterSchemaSimplification := calculateToolsSize(compressedTools)
|
||||
log.Debugf("kiro: size after schema simplification: %d bytes (reduced by %d bytes)",
|
||||
sizeAfterSchemaSimplification, originalSize-sizeAfterSchemaSimplification)
|
||||
|
||||
// Check if we're within target after schema simplification
|
||||
if sizeAfterSchemaSimplification <= kirocommon.ToolCompressionTargetSize {
|
||||
log.Infof("kiro: compression complete after schema simplification, final size: %d bytes",
|
||||
sizeAfterSchemaSimplification)
|
||||
return compressedTools
|
||||
}
|
||||
|
||||
// Step 2: Compress descriptions proportionally
|
||||
sizeToReduce := float64(sizeAfterSchemaSimplification - kirocommon.ToolCompressionTargetSize)
|
||||
var totalDescLen float64
|
||||
for _, tool := range compressedTools {
|
||||
totalDescLen += float64(len(tool.ToolSpecification.Description))
|
||||
}
|
||||
|
||||
if totalDescLen > 0 {
|
||||
// Assume size reduction comes primarily from descriptions.
|
||||
keepRatio := 1.0 - (sizeToReduce / totalDescLen)
|
||||
if keepRatio > 1.0 {
|
||||
keepRatio = 1.0
|
||||
} else if keepRatio < 0 {
|
||||
keepRatio = 0
|
||||
}
|
||||
|
||||
for i := range compressedTools {
|
||||
desc := compressedTools[i].ToolSpecification.Description
|
||||
targetLen := int(float64(len(desc)) * keepRatio)
|
||||
compressedTools[i].ToolSpecification.Description = compressToolDescription(desc, targetLen)
|
||||
}
|
||||
}
|
||||
|
||||
finalSize := calculateToolsSize(compressedTools)
|
||||
log.Infof("kiro: compression complete, original: %d bytes, final: %d bytes (%.1f%% reduction)",
|
||||
originalSize, finalSize, float64(originalSize-finalSize)/float64(originalSize)*100)
|
||||
|
||||
return compressedTools
|
||||
}
|
||||
@@ -6,6 +6,14 @@ const (
|
||||
// Kiro API limit is 10240 bytes, leave room for "..."
|
||||
KiroMaxToolDescLen = 10237
|
||||
|
||||
// ToolCompressionTargetSize is the target total size for compressed tools (20KB).
|
||||
// If tools exceed this size, compression will be applied.
|
||||
ToolCompressionTargetSize = 20 * 1024 // 20KB
|
||||
|
||||
// MinToolDescriptionLength is the minimum description length after compression.
|
||||
// Descriptions will not be shortened below this length.
|
||||
MinToolDescriptionLength = 50
|
||||
|
||||
// ThinkingStartTag is the start tag for thinking blocks in responses.
|
||||
ThinkingStartTag = "<thinking>"
|
||||
|
||||
@@ -72,4 +80,4 @@ You MUST follow these rules for ALL file operations. Violation causes server tim
|
||||
- Failed writes waste time and require retry
|
||||
|
||||
REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ package synthesizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
@@ -63,6 +64,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
|
||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if entry.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(entry.Priority)
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
@@ -107,6 +111,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea
|
||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
@@ -151,6 +158,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||
}
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
@@ -206,6 +216,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if compat.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
@@ -237,6 +250,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if compat.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
@@ -279,6 +295,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
|
||||
"base_url": base,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if compat.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
|
||||
@@ -60,6 +60,11 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := antigravityCallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||
|
||||
state, err := misc.GenerateRandomState()
|
||||
@@ -67,7 +72,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
srv, port, cbChan, errServer := startAntigravityCallbackServer()
|
||||
srv, port, cbChan, errServer := startAntigravityCallbackServer(callbackPort)
|
||||
if errServer != nil {
|
||||
return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer)
|
||||
}
|
||||
@@ -224,13 +229,16 @@ type callbackResult struct {
|
||||
State string
|
||||
}
|
||||
|
||||
func startAntigravityCallbackServer() (*http.Server, int, <-chan callbackResult, error) {
|
||||
addr := fmt.Sprintf(":%d", antigravityCallbackPort)
|
||||
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||
if port <= 0 {
|
||||
port = antigravityCallbackPort
|
||||
}
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
port = listener.Addr().(*net.TCPAddr).Port
|
||||
resultCh := make(chan callbackResult, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@@ -374,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
|
||||
// Call loadCodeAssist to get the project
|
||||
loadReqBody := map[string]any{
|
||||
"metadata": map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
@@ -434,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return "", fmt.Errorf("no cloudaicompanionProject in response")
|
||||
tierID := "legacy-tier"
|
||||
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||
for _, rawTier := range tiers {
|
||||
tier, okTier := rawTier.(map[string]any)
|
||||
if !okTier {
|
||||
continue
|
||||
}
|
||||
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||
tierID = strings.TrimSpace(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
|
||||
// It returns an empty string when the operation times out or completes without a project ID.
|
||||
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
fmt.Println("Antigravity: onboarding user...", tierID)
|
||||
requestBody := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(requestBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
maxAttempts := 5
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||
|
||||
reqCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if reqCtx == nil {
|
||||
reqCtx = context.Background()
|
||||
}
|
||||
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if errRequest != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("close body error: %v", errClose)
|
||||
}
|
||||
cancel()
|
||||
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var data map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
if done, okDone := data["done"].(bool); okDone && done {
|
||||
projectID := ""
|
||||
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||
case map[string]any:
|
||||
if id, okID := projectValue["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
case string:
|
||||
projectID = strings.TrimSpace(projectValue)
|
||||
}
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no project_id in response")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||
if len(responsePreview) > 500 {
|
||||
responsePreview = responsePreview[:500]
|
||||
}
|
||||
|
||||
responseErr := responsePreview
|
||||
if len(responseErr) > 200 {
|
||||
responseErr = responseErr[:200]
|
||||
}
|
||||
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -47,6 +47,11 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
pkceCodes, err := claude.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude pkce generation failed: %w", err)
|
||||
@@ -57,7 +62,7 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
return nil, fmt.Errorf("claude state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := claude.NewOAuthServer(a.CallbackPort)
|
||||
oauthServer := claude.NewOAuthServer(callbackPort)
|
||||
if err = oauthServer.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "already in use") {
|
||||
return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err)
|
||||
@@ -84,15 +89,15 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
fmt.Println("Opening browser for Claude authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,11 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
pkceCodes, err := codex.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codex pkce generation failed: %w", err)
|
||||
@@ -57,7 +62,7 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return nil, fmt.Errorf("codex state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := codex.NewOAuthServer(a.CallbackPort)
|
||||
oauthServer := codex.NewOAuthServer(callbackPort)
|
||||
if err = oauthServer.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "already in use") {
|
||||
return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err)
|
||||
@@ -83,15 +88,15 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
fmt.Println("Opening browser for Codex authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -45,8 +45,9 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: opts.NoBrowser,
|
||||
Prompt: opts.Prompt,
|
||||
NoBrowser: opts.NoBrowser,
|
||||
CallbackPort: opts.CallbackPort,
|
||||
Prompt: opts.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini authentication failed: %w", err)
|
||||
|
||||
@@ -42,9 +42,14 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := iflow.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
authSvc := iflow.NewIFlowAuth(cfg)
|
||||
|
||||
oauthServer := iflow.NewOAuthServer(iflow.CallbackPort)
|
||||
oauthServer := iflow.NewOAuthServer(callbackPort)
|
||||
if err := oauthServer.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "already in use") {
|
||||
return nil, fmt.Errorf("iflow authentication server port in use: %w", err)
|
||||
@@ -64,21 +69,21 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflow.CallbackPort)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, callbackPort)
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for iFlow authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,10 +14,11 @@ var ErrRefreshNotSupported = errors.New("cliproxy auth: refresh not supported")
|
||||
// LoginOptions captures generic knobs shared across authenticators.
|
||||
// Provider-specific logic can inspect Metadata for extra parameters.
|
||||
type LoginOptions struct {
|
||||
NoBrowser bool
|
||||
ProjectID string
|
||||
Metadata map[string]string
|
||||
Prompt func(prompt string) (string, error)
|
||||
NoBrowser bool
|
||||
ProjectID string
|
||||
CallbackPort int
|
||||
Metadata map[string]string
|
||||
Prompt func(prompt string) (string, error)
|
||||
}
|
||||
|
||||
// Authenticator manages login and optional refresh flows for a provider.
|
||||
|
||||
@@ -271,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
if len(normalized) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -281,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -309,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
if len(normalized) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -319,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeCountWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -347,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
if len(normalized) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -357,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
return m.executeStreamWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -378,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
@@ -1191,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
p := strings.TrimSpace(strings.ToLower(provider))
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[p] = struct{}{}
|
||||
}
|
||||
if len(providerSet) == 0 {
|
||||
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
candidates := make([]*Auth, 0, len(m.auths))
|
||||
modelKey := strings.TrimSpace(model)
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
for _, candidate := range m.auths {
|
||||
if candidate == nil || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, used := tried[candidate.ID]; used {
|
||||
continue
|
||||
}
|
||||
if _, ok := m.executors[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
|
||||
if errPick != nil {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", errPick
|
||||
}
|
||||
if selected == nil {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
|
||||
executor, okExecutor := m.executors[providerKey]
|
||||
if !okExecutor {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
authCopy := selected.Clone()
|
||||
m.mu.RUnlock()
|
||||
if !selected.indexAssigned {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||
current.EnsureIndex()
|
||||
authCopy = current.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return authCopy, executor, providerKey, nil
|
||||
}
|
||||
|
||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||
if m.store == nil || auth == nil {
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header {
|
||||
return headers
|
||||
}
|
||||
|
||||
func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make([]*Auth, 0, len(auths))
|
||||
func authPriority(auth *Auth) int {
|
||||
if auth == nil || auth.Attributes == nil {
|
||||
return 0
|
||||
}
|
||||
raw := strings.TrimSpace(auth.Attributes["priority"])
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
parsed, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make(map[int][]*Auth)
|
||||
for i := 0; i < len(auths); i++ {
|
||||
candidate := auths[i]
|
||||
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
||||
if !blocked {
|
||||
available = append(available, candidate)
|
||||
priority := authPriority(candidate)
|
||||
available[priority] = append(available[priority], candidate)
|
||||
continue
|
||||
}
|
||||
if reason == blockReasonCooldown {
|
||||
@@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []*
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(available) > 1 {
|
||||
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
||||
}
|
||||
return available, cooldownCount, earliest
|
||||
}
|
||||
|
||||
@@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
||||
}
|
||||
|
||||
available, cooldownCount, earliest := collectAvailable(auths, model, now)
|
||||
if len(available) == 0 {
|
||||
availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
|
||||
if len(availableByPriority) == 0 {
|
||||
if cooldownCount == len(auths) && !earliest.IsZero() {
|
||||
providerForError := provider
|
||||
if providerForError == "mixed" {
|
||||
providerForError = ""
|
||||
}
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return nil, newModelCooldownError(model, provider, resetIn)
|
||||
return nil, newModelCooldownError(model, providerForError, resetIn)
|
||||
}
|
||||
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
bestPriority := 0
|
||||
found := false
|
||||
for priority := range availableByPriority {
|
||||
if !found || priority > bestPriority {
|
||||
bestPriority = priority
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
available := availableByPriority[bestPriority]
|
||||
if len(available) > 1 {
|
||||
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
||||
}
|
||||
return available, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
@@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
auths := []*Auth{
|
||||
{ID: "c", Attributes: map[string]string{"priority": "0"}},
|
||||
{ID: "a", Attributes: map[string]string{"priority": "10"}},
|
||||
{ID: "b", Attributes: map[string]string{"priority": "10"}},
|
||||
}
|
||||
|
||||
want := []string{"a", "b", "a", "b"}
|
||||
for i, id := range want {
|
||||
got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != id {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
|
||||
}
|
||||
if got.ID == "c" {
|
||||
t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &FillFirstSelector{}
|
||||
now := time.Now()
|
||||
model := "test-model"
|
||||
|
||||
high := &Auth{
|
||||
ID: "high",
|
||||
Attributes: map[string]string{"priority": "10"},
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusActive,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: now.Add(30 * time.Minute),
|
||||
Quota: QuotaState{
|
||||
Exceeded: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
|
||||
|
||||
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() auth = nil")
|
||||
}
|
||||
if got.ID != "low" {
|
||||
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
||||
selector := &RoundRobinSelector{}
|
||||
auths := []*Auth{
|
||||
|
||||
Reference in New Issue
Block a user