From c7e8830a563d2615753162381fc2d0937c5ce0aa Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 17 Jan 2026 22:53:10 +0800 Subject: [PATCH] refactor(thinking): pass source and target formats to ApplyThinking for cross-format validation Update ApplyThinking signature to accept fromFormat and toFormat parameters instead of a single provider string. This enables: - Proper level-to-budget conversion when source is level-based (openai/codex) and target is budget-based (gemini/claude) - Strict budget range validation when source and target formats match - Level clamping to nearest supported level for cross-format requests - Format alias resolution in SDK translator registry for codex/openai-response Also adds ErrBudgetOutOfRange error code and improves iflow config extraction to fall back to openai format when iflow-specific config is not present. --- .../runtime/executor/aistudio_executor.go | 2 +- .../runtime/executor/antigravity_executor.go | 8 +- internal/runtime/executor/claude_executor.go | 4 +- internal/runtime/executor/codex_executor.go | 6 +- .../runtime/executor/gemini_cli_executor.go | 6 +- internal/runtime/executor/gemini_executor.go | 6 +- .../executor/gemini_vertex_executor.go | 12 +- internal/runtime/executor/iflow_executor.go | 4 +- .../executor/openai_compat_executor.go | 6 +- internal/runtime/executor/qwen_executor.go | 4 +- internal/thinking/apply.go | 86 ++++-- internal/thinking/errors.go | 4 + internal/thinking/strip.go | 32 ++- internal/thinking/validate.go | 270 ++++++++++++------ sdk/translator/registry.go | 62 +++- 15 files changed, 341 insertions(+), 171 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index fffb50c4..a020c670 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -393,7 +393,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c } originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream) payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream) - payload, err := thinking.ApplyThinking(payload, req.Model, "gemini") + payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String()) if err != nil { return nil, translatedPayload{}, err } diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 47113cfc..99392188 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -137,7 +137,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translated, err = thinking.ApplyThinking(translated, req.Model, "antigravity") + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -256,7 +256,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - translated, err = thinking.ApplyThinking(translated, req.Model, "antigravity") + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -622,7 +622,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - translated, err = thinking.ApplyThinking(translated, req.Model, "antigravity") + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -802,7 +802,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut // Prepare payload once (doesn't depend on baseURL) payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - payload, err := thinking.ApplyThinking(payload, req.Model, "antigravity") + payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index b4cbd450..17c5a143 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -106,7 +106,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, "claude") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -239,7 +239,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, "claude") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return nil, err } diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index eeefe6bc..cc0e32a1 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -96,7 +96,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body = sdktranslator.TranslateRequest(from, to, baseModel, body, false) body = misc.StripCodexUserAgent(body) - body, err = thinking.ApplyThinking(body, req.Model, "codex") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -208,7 +208,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body = sdktranslator.TranslateRequest(from, to, baseModel, body, true) body = misc.StripCodexUserAgent(body) - body, err = thinking.ApplyThinking(body, req.Model, "codex") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -316,7 +316,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth body = sdktranslator.TranslateRequest(from, to, baseModel, body, false) body = misc.StripCodexUserAgent(body) - body, err := thinking.ApplyThinking(body, req.Model, "codex") + body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index add01cb3..b23406af 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -123,7 +123,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - basePayload, err = thinking.ApplyThinking(basePayload, req.Model, "gemini-cli") + basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -272,7 +272,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - basePayload, err = thinking.ApplyThinking(basePayload, req.Model, "gemini-cli") + basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -479,7 +479,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. for range models { payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - payload, err = thinking.ApplyThinking(payload, req.Model, "gemini-cli") + payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 4cc5d945..e9f9dbca 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -120,7 +120,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - body, err = thinking.ApplyThinking(body, req.Model, "gemini") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -222,7 +222,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - body, err = thinking.ApplyThinking(body, req.Model, "gemini") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -338,7 +338,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, "gemini") + translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 8a412b47..20e59b3f 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -170,7 +170,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - body, err = thinking.ApplyThinking(body, req.Model, "gemini") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -272,7 +272,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - body, err = thinking.ApplyThinking(body, req.Model, "gemini") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -375,7 +375,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - body, err = thinking.ApplyThinking(body, req.Model, "gemini") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -494,7 +494,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - body, err = thinking.ApplyThinking(body, req.Model, "gemini") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -605,7 +605,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, "gemini") + translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } @@ -689,7 +689,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, "gemini") + translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 6ce4221c..3e6ca4e5 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -92,7 +92,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, "iflow") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow") if err != nil { return resp, err } @@ -190,7 +190,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, "iflow") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow") if err != nil { return nil, err } diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 6ae9103f..a2bef724 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -92,7 +92,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated) - translated, err = thinking.ApplyThinking(translated, req.Model, "openai") + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -187,7 +187,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated) - translated, err = thinking.ApplyThinking(translated, req.Model, "openai") + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) if err != nil { return nil, err } @@ -297,7 +297,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau modelForCounting := baseModel - translated, err := thinking.ApplyThinking(translated, req.Model, "openai") + translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index ff35c935..260165d9 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -86,7 +86,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, "openai") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return resp, err } @@ -172,7 +172,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, "openai") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) if err != nil { return nil, err } diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index 003405c0..fe7d59b4 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -2,6 +2,8 @@ package thinking import ( + "strings" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -59,7 +61,8 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool { // Parameters: // - body: Original request body JSON // - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)") -// - provider: Provider name (gemini, gemini-cli, antigravity, claude, openai, codex, iflow) +// - fromFormat: Source request format (e.g., openai, codex, gemini) +// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow) // // Returns: // - Modified request body JSON with thinking configuration applied @@ -76,16 +79,21 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool { // Example: // // // With suffix - suffix config takes priority -// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini") +// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini") // // // Without suffix - uses body config -// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini") -func ApplyThinking(body []byte, model string, provider string) ([]byte, error) { +// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini") +func ApplyThinking(body []byte, model string, fromFormat string, toFormat string) ([]byte, error) { + providerFormat := strings.ToLower(strings.TrimSpace(toFormat)) + fromFormat = strings.ToLower(strings.TrimSpace(fromFormat)) + if fromFormat == "" { + fromFormat = providerFormat + } // 1. Route check: Get provider applier - applier := GetProviderApplier(provider) + applier := GetProviderApplier(providerFormat) if applier == nil { log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": model, }).Debug("thinking: unknown provider, passthrough |") return body, nil @@ -100,19 +108,19 @@ func ApplyThinking(body []byte, model string, provider string) ([]byte, error) { // Unknown models are treated as user-defined so thinking config can still be applied. // The upstream service is responsible for validating the configuration. if IsUserDefinedModel(modelInfo) { - return applyUserDefinedModel(body, modelInfo, provider, suffixResult) + return applyUserDefinedModel(body, modelInfo, fromFormat, providerFormat, suffixResult) } if modelInfo.Thinking == nil { - config := extractThinkingConfig(body, provider) + config := extractThinkingConfig(body, providerFormat) if hasThinkingConfig(config) { log.WithFields(log.Fields{ "model": baseModel, - "provider": provider, + "provider": providerFormat, }).Debug("thinking: model does not support thinking, stripping config |") - return StripThinkingConfig(body, provider), nil + return StripThinkingConfig(body, providerFormat), nil } log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": baseModel, }).Debug("thinking: model does not support thinking, passthrough |") return body, nil @@ -121,19 +129,19 @@ func ApplyThinking(body []byte, model string, provider string) ([]byte, error) { // 4. Get config: suffix priority over body var config ThinkingConfig if suffixResult.HasSuffix { - config = parseSuffixToConfig(suffixResult.RawSuffix, provider, model) + config = parseSuffixToConfig(suffixResult.RawSuffix, providerFormat, model) log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": model, "mode": config.Mode, "budget": config.Budget, "level": config.Level, }).Debug("thinking: config from model suffix |") } else { - config = extractThinkingConfig(body, provider) + config = extractThinkingConfig(body, providerFormat) if hasThinkingConfig(config) { log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": modelInfo.ID, "mode": config.Mode, "budget": config.Budget, @@ -144,17 +152,17 @@ func ApplyThinking(body []byte, model string, provider string) ([]byte, error) { if !hasThinkingConfig(config) { log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": modelInfo.ID, }).Debug("thinking: no config found, passthrough |") return body, nil } // 5. Validate and normalize configuration - validated, err := ValidateConfig(config, modelInfo, provider) + validated, err := ValidateConfig(config, modelInfo, fromFormat, providerFormat) if err != nil { log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": modelInfo.ID, "error": err.Error(), }).Warn("thinking: validation failed |") @@ -167,14 +175,14 @@ func ApplyThinking(body []byte, model string, provider string) ([]byte, error) { // Defensive check: ValidateConfig should never return (nil, nil) if validated == nil { log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": modelInfo.ID, }).Warn("thinking: ValidateConfig returned nil config without error, passthrough |") return body, nil } log.WithFields(log.Fields{ - "provider": provider, + "provider": providerFormat, "model": modelInfo.ID, "mode": validated.Mode, "budget": validated.Budget, @@ -228,7 +236,7 @@ func parseSuffixToConfig(rawSuffix, provider, model string) ThinkingConfig { // applyUserDefinedModel applies thinking configuration for user-defined models // without ThinkingSupport validation. -func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, provider string, suffixResult SuffixResult) ([]byte, error) { +func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromFormat, toFormat string, suffixResult SuffixResult) ([]byte, error) { // Get model ID for logging modelID := "" if modelInfo != nil { @@ -240,39 +248,57 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, provider // Get config: suffix priority over body var config ThinkingConfig if suffixResult.HasSuffix { - config = parseSuffixToConfig(suffixResult.RawSuffix, provider, modelID) + config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID) } else { - config = extractThinkingConfig(body, provider) + config = extractThinkingConfig(body, toFormat) } if !hasThinkingConfig(config) { log.WithFields(log.Fields{ "model": modelID, - "provider": provider, + "provider": toFormat, }).Debug("thinking: user-defined model, passthrough (no config) |") return body, nil } - applier := GetProviderApplier(provider) + applier := GetProviderApplier(toFormat) if applier == nil { log.WithFields(log.Fields{ "model": modelID, - "provider": provider, + "provider": toFormat, }).Debug("thinking: user-defined model, passthrough (unknown provider) |") return body, nil } log.WithFields(log.Fields{ - "provider": provider, + "provider": toFormat, "model": modelID, "mode": config.Mode, "budget": config.Budget, "level": config.Level, }).Debug("thinking: applying config for user-defined model (skip validation)") + config = normalizeUserDefinedConfig(config, fromFormat, toFormat) return applier.Apply(body, config, modelInfo) } +func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat string) ThinkingConfig { + if config.Mode != ModeLevel { + return config + } + if !isBudgetBasedProvider(toFormat) || !isLevelBasedProvider(fromFormat) { + return config + } + budget, ok := ConvertLevelToBudget(string(config.Level)) + if !ok { + return config + } + config.Mode = ModeBudget + config.Budget = budget + config.Level = "" + return config +} + // extractThinkingConfig extracts provider-specific thinking config from request body. func extractThinkingConfig(body []byte, provider string) ThinkingConfig { if len(body) == 0 || !gjson.ValidBytes(body) { @@ -289,7 +315,11 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig { case "codex": return extractCodexConfig(body) case "iflow": - return extractIFlowConfig(body) + config := extractIFlowConfig(body) + if hasThinkingConfig(config) { + return config + } + return extractOpenAIConfig(body) default: return ThinkingConfig{} } diff --git a/internal/thinking/errors.go b/internal/thinking/errors.go index 1cf9ccd0..5eed9381 100644 --- a/internal/thinking/errors.go +++ b/internal/thinking/errors.go @@ -24,6 +24,10 @@ const ( // Example: using level with a budget-only model ErrLevelNotSupported ErrorCode = "LEVEL_NOT_SUPPORTED" + // ErrBudgetOutOfRange indicates the budget value is outside model range. + // Example: budget 64000 exceeds max 20000 + ErrBudgetOutOfRange ErrorCode = "BUDGET_OUT_OF_RANGE" + // ErrProviderMismatch indicates the provider does not match the model. // Example: applying Claude format to a Gemini model ErrProviderMismatch ErrorCode = "PROVIDER_MISMATCH" diff --git a/internal/thinking/strip.go b/internal/thinking/strip.go index 4904d4d5..eb691715 100644 --- a/internal/thinking/strip.go +++ b/internal/thinking/strip.go @@ -27,28 +27,32 @@ func StripThinkingConfig(body []byte, provider string) []byte { return body } + var paths []string switch provider { case "claude": - result, _ := sjson.DeleteBytes(body, "thinking") - return result + paths = []string{"thinking"} case "gemini": - result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig") - return result + paths = []string{"generationConfig.thinkingConfig"} case "gemini-cli", "antigravity": - result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig") - return result + paths = []string{"request.generationConfig.thinkingConfig"} case "openai": - result, _ := sjson.DeleteBytes(body, "reasoning_effort") - return result + paths = []string{"reasoning_effort"} case "codex": - result, _ := sjson.DeleteBytes(body, "reasoning.effort") - return result + paths = []string{"reasoning.effort"} case "iflow": - result, _ := sjson.DeleteBytes(body, "chat_template_kwargs.enable_thinking") - result, _ = sjson.DeleteBytes(result, "chat_template_kwargs.clear_thinking") - result, _ = sjson.DeleteBytes(result, "reasoning_split") - return result + paths = []string{ + "chat_template_kwargs.enable_thinking", + "chat_template_kwargs.clear_thinking", + "reasoning_split", + "reasoning_effort", + } default: return body } + + result := body + for _, path := range paths { + result, _ = sjson.DeleteBytes(result, path) + } + return result } diff --git a/internal/thinking/validate.go b/internal/thinking/validate.go index aabe04eb..853e187d 100644 --- a/internal/thinking/validate.go +++ b/internal/thinking/validate.go @@ -9,64 +9,6 @@ import ( log "github.com/sirupsen/logrus" ) -// ClampBudget clamps a budget value to the model's supported range. -// -// Logging: -// - Warn when value=0 but ZeroAllowed=false -// - Debug when value is clamped to min/max -// -// Fields: provider, model, original_value, clamped_to, min, max -func ClampBudget(value int, modelInfo *registry.ModelInfo, provider string) int { - model := "unknown" - support := (*registry.ThinkingSupport)(nil) - if modelInfo != nil { - if modelInfo.ID != "" { - model = modelInfo.ID - } - support = modelInfo.Thinking - } - if support == nil { - return value - } - - // Auto value (-1) passes through without clamping. - if value == -1 { - return value - } - - min := support.Min - max := support.Max - if value == 0 && !support.ZeroAllowed { - log.WithFields(log.Fields{ - "provider": provider, - "model": model, - "original_value": value, - "clamped_to": min, - "min": min, - "max": max, - }).Warn("thinking: budget zero not allowed |") - return min - } - - // Some models are level-only and do not define numeric budget ranges. - if min == 0 && max == 0 { - return value - } - - if value < min { - if value == 0 && support.ZeroAllowed { - return 0 - } - logClamp(provider, model, value, min, min, max) - return min - } - if value > max { - logClamp(provider, model, value, max, min, max) - return max - } - return value -} - // ValidateConfig validates a thinking configuration against model capabilities. // // This function performs comprehensive validation: @@ -74,10 +16,14 @@ func ClampBudget(value int, modelInfo *registry.ModelInfo, provider string) int // - Auto-converts between Budget and Level formats based on model capability // - Validates that requested level is in the model's supported levels list // - Clamps budget values to model's allowed range +// - When converting Budget -> Level for level-only models, clamps the derived standard level to the nearest supported level +// (special values none/auto are preserved) // // Parameters: // - config: The thinking configuration to validate // - support: Model's ThinkingSupport properties (nil means no thinking support) +// - fromFormat: Source provider format (used to determine strict validation rules) +// - toFormat: Target provider format // // Returns: // - Normalized ThinkingConfig with clamped values @@ -87,9 +33,9 @@ func ClampBudget(value int, modelInfo *registry.ModelInfo, provider string) int // - Budget-only model + Level config → Level converted to Budget // - Level-only model + Budget config → Budget converted to Level // - Hybrid model → preserve original format -func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, provider string) (*ThinkingConfig, error) { +func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFormat, toFormat string) (*ThinkingConfig, error) { + fromFormat, toFormat = strings.ToLower(strings.TrimSpace(fromFormat)), strings.ToLower(strings.TrimSpace(toFormat)) normalized := config - model := "unknown" support := (*registry.ThinkingSupport)(nil) if modelInfo != nil { @@ -106,6 +52,9 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, provid return &normalized, nil } + allowClampUnsupported := isBudgetBasedProvider(fromFormat) && isLevelBasedProvider(toFormat) + strictBudget := fromFormat != "" && fromFormat == toFormat + capability := detectModelCapability(modelInfo) switch capability { case CapabilityBudgetOnly: @@ -127,8 +76,10 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, provid if !ok { return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("budget %d cannot be converted to a valid level", normalized.Budget)) } + // When converting Budget -> Level for level-only models, clamp the derived standard level + // to the nearest supported level. Special values (none/auto) are preserved. normalized.Mode = ModeLevel - normalized.Level = ThinkingLevel(level) + normalized.Level = clampLevel(ThinkingLevel(level), modelInfo, toFormat) normalized.Budget = 0 } case CapabilityHybrid: @@ -151,18 +102,35 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, provid if len(support.Levels) > 0 && normalized.Mode == ModeLevel { if !isLevelSupported(string(normalized.Level), support.Levels) { - validLevels := normalizeLevels(support.Levels) - message := fmt.Sprintf("level %q not supported, valid levels: %s", strings.ToLower(string(normalized.Level)), strings.Join(validLevels, ", ")) - return nil, NewThinkingError(ErrLevelNotSupported, message) + if allowClampUnsupported { + normalized.Level = clampLevel(normalized.Level, modelInfo, toFormat) + } + if !isLevelSupported(string(normalized.Level), support.Levels) { + // User explicitly specified an unsupported level - return error + // (budget-derived levels may be clamped based on source format) + validLevels := normalizeLevels(support.Levels) + message := fmt.Sprintf("level %q not supported, valid levels: %s", strings.ToLower(string(normalized.Level)), strings.Join(validLevels, ", ")) + return nil, NewThinkingError(ErrLevelNotSupported, message) + } + } + } + + if strictBudget && normalized.Mode == ModeBudget { + min, max := support.Min, support.Max + if min != 0 || max != 0 { + if normalized.Budget < min || normalized.Budget > max || (normalized.Budget == 0 && !support.ZeroAllowed) { + message := fmt.Sprintf("budget %d out of range [%d,%d]", normalized.Budget, min, max) + return nil, NewThinkingError(ErrBudgetOutOfRange, message) + } } } // Convert ModeAuto to mid-range if dynamic not allowed if normalized.Mode == ModeAuto && !support.DynamicAllowed { - normalized = convertAutoToMidRange(normalized, support, provider, model) + normalized = convertAutoToMidRange(normalized, support, toFormat, model) } - if normalized.Mode == ModeNone && provider == "claude" { + if normalized.Mode == ModeNone && toFormat == "claude" { // Claude supports explicit disable via thinking.type="disabled". // Keep Budget=0 so applier can omit budget_tokens. normalized.Budget = 0 @@ -170,7 +138,7 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, provid } else { switch normalized.Mode { case ModeBudget, ModeAuto, ModeNone: - normalized.Budget = ClampBudget(normalized.Budget, modelInfo, provider) + normalized.Budget = clampBudget(normalized.Budget, modelInfo, toFormat) } // ModeNone with clamped Budget > 0: set Level to lowest for Level-only/Hybrid models @@ -183,23 +151,6 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, provid return &normalized, nil } -func isLevelSupported(level string, supported []string) bool { - for _, candidate := range supported { - if strings.EqualFold(level, strings.TrimSpace(candidate)) { - return true - } - } - return false -} - -func normalizeLevels(levels []string) []string { - normalized := make([]string, 0, len(levels)) - for _, level := range levels { - normalized = append(normalized, strings.ToLower(strings.TrimSpace(level))) - } - return normalized -} - // convertAutoToMidRange converts ModeAuto to a mid-range value when dynamic is not allowed. // // This function handles the case where a model does not support dynamic/auto thinking. @@ -246,7 +197,156 @@ func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupp return config } -// logClamp logs a debug message when budget clamping occurs. +// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest. +var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh} + +// clampLevel clamps the given level to the nearest supported level. +// On tie, prefers the lower level. +func clampLevel(level ThinkingLevel, modelInfo *registry.ModelInfo, provider string) ThinkingLevel { + model := "unknown" + var supported []string + if modelInfo != nil { + if modelInfo.ID != "" { + model = modelInfo.ID + } + if modelInfo.Thinking != nil { + supported = modelInfo.Thinking.Levels + } + } + + if len(supported) == 0 || isLevelSupported(string(level), supported) { + return level + } + + pos := levelIndex(string(level)) + if pos == -1 { + return level + } + bestIdx, bestDist := -1, len(standardLevelOrder)+1 + + for _, s := range supported { + if idx := levelIndex(strings.TrimSpace(s)); idx != -1 { + if dist := abs(pos - idx); dist < bestDist || (dist == bestDist && idx < bestIdx) { + bestIdx, bestDist = idx, dist + } + } + } + + if bestIdx >= 0 { + clamped := standardLevelOrder[bestIdx] + log.WithFields(log.Fields{ + "provider": provider, + "model": model, + "original_level": string(level), + "clamped_to": string(clamped), + }).Debug("thinking: level clamped |") + return clamped + } + return level +} + +// clampBudget clamps a budget value to the model's supported range. +func clampBudget(value int, modelInfo *registry.ModelInfo, provider string) int { + model := "unknown" + support := (*registry.ThinkingSupport)(nil) + if modelInfo != nil { + if modelInfo.ID != "" { + model = modelInfo.ID + } + support = modelInfo.Thinking + } + if support == nil { + return value + } + + // Auto value (-1) passes through without clamping. + if value == -1 { + return value + } + + min, max := support.Min, support.Max + if value == 0 && !support.ZeroAllowed { + log.WithFields(log.Fields{ + "provider": provider, + "model": model, + "original_value": value, + "clamped_to": min, + "min": min, + "max": max, + }).Warn("thinking: budget zero not allowed |") + return min + } + + // Some models are level-only and do not define numeric budget ranges. + if min == 0 && max == 0 { + return value + } + + if value < min { + if value == 0 && support.ZeroAllowed { + return 0 + } + logClamp(provider, model, value, min, min, max) + return min + } + if value > max { + logClamp(provider, model, value, max, min, max) + return max + } + return value +} + +func isLevelSupported(level string, supported []string) bool { + for _, s := range supported { + if strings.EqualFold(level, strings.TrimSpace(s)) { + return true + } + } + return false +} + +func levelIndex(level string) int { + for i, l := range standardLevelOrder { + if strings.EqualFold(level, string(l)) { + return i + } + } + return -1 +} + +func normalizeLevels(levels []string) []string { + out := make([]string, len(levels)) + for i, l := range levels { + out[i] = strings.ToLower(strings.TrimSpace(l)) + } + return out +} + +func isBudgetBasedProvider(provider string) bool { + switch provider { + case "gemini", "gemini-cli", "antigravity", "claude": + return true + default: + return false + } +} + +func isLevelBasedProvider(provider string) bool { + switch provider { + case "openai", "openai-response", "codex": + return true + default: + return false + } +} + +func abs(x int) int { + if x < 0 { + return -x + } + return x +} + func logClamp(provider, model string, original, clampedTo, min, max int) { log.WithFields(log.Fields{ "provider": provider, diff --git a/sdk/translator/registry.go b/sdk/translator/registry.go index ace97137..882e80f6 100644 --- a/sdk/translator/registry.go +++ b/sdk/translator/registry.go @@ -38,15 +38,31 @@ func (r *Registry) Register(from, to Format, request RequestTransform, response r.responses[from][to] = response } +// formatAliases returns compatible aliases for a format, ordered by preference. +func formatAliases(format Format) []Format { + switch format { + case "codex": + return []Format{"codex", "openai-response"} + case "openai-response": + return []Format{"openai-response", "codex"} + default: + return []Format{format} + } +} + // TranslateRequest converts a payload between schemas, returning the original payload // if no translator is registered. func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte { r.mu.RLock() defer r.mu.RUnlock() - if byTarget, ok := r.requests[from]; ok { - if fn, isOk := byTarget[to]; isOk && fn != nil { - return fn(model, rawJSON, stream) + for _, fromFormat := range formatAliases(from) { + if byTarget, ok := r.requests[fromFormat]; ok { + for _, toFormat := range formatAliases(to) { + if fn, isOk := byTarget[toFormat]; isOk && fn != nil { + return fn(model, rawJSON, stream) + } + } } } return rawJSON @@ -57,9 +73,13 @@ func (r *Registry) HasResponseTransformer(from, to Format) bool { r.mu.RLock() defer r.mu.RUnlock() - if byTarget, ok := r.responses[from]; ok { - if _, isOk := byTarget[to]; isOk { - return true + for _, toFormat := range formatAliases(to) { + if byTarget, ok := r.responses[toFormat]; ok { + for _, fromFormat := range formatAliases(from) { + if _, isOk := byTarget[fromFormat]; isOk { + return true + } + } } } return false @@ -70,9 +90,13 @@ func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model s r.mu.RLock() defer r.mu.RUnlock() - if byTarget, ok := r.responses[to]; ok { - if fn, isOk := byTarget[from]; isOk && fn.Stream != nil { - return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) + for _, toFormat := range formatAliases(to) { + if byTarget, ok := r.responses[toFormat]; ok { + for _, fromFormat := range formatAliases(from) { + if fn, isOk := byTarget[fromFormat]; isOk && fn.Stream != nil { + return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) + } + } } } return []string{string(rawJSON)} @@ -83,9 +107,13 @@ func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, mode r.mu.RLock() defer r.mu.RUnlock() - if byTarget, ok := r.responses[to]; ok { - if fn, isOk := byTarget[from]; isOk && fn.NonStream != nil { - return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) + for _, toFormat := range formatAliases(to) { + if byTarget, ok := r.responses[toFormat]; ok { + for _, fromFormat := range formatAliases(from) { + if fn, isOk := byTarget[fromFormat]; isOk && fn.NonStream != nil { + return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) + } + } } } return string(rawJSON) @@ -96,9 +124,13 @@ func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, cou r.mu.RLock() defer r.mu.RUnlock() - if byTarget, ok := r.responses[to]; ok { - if fn, isOk := byTarget[from]; isOk && fn.TokenCount != nil { - return fn.TokenCount(ctx, count) + for _, toFormat := range formatAliases(to) { + if byTarget, ok := r.responses[toFormat]; ok { + for _, fromFormat := range formatAliases(from) { + if fn, isOk := byTarget[fromFormat]; isOk && fn.TokenCount != nil { + return fn.TokenCount(ctx, count) + } + } } } return string(rawJSON)