From e641fde25cf9f8146560042a2c3ba43ef7f329a1 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 20 Jan 2026 09:57:06 +0800 Subject: [PATCH] feat(registry): support provider-specific model info lookup --- internal/registry/model_registry.go | 54 +++++++++++++++---- .../runtime/executor/aistudio_executor.go | 2 +- .../runtime/executor/antigravity_executor.go | 8 +-- internal/runtime/executor/claude_executor.go | 4 +- internal/runtime/executor/codex_executor.go | 6 +-- .../runtime/executor/gemini_cli_executor.go | 6 +-- internal/runtime/executor/gemini_executor.go | 6 +-- .../executor/gemini_vertex_executor.go | 12 ++--- internal/runtime/executor/iflow_executor.go | 4 +- .../executor/openai_compat_executor.go | 6 +-- internal/runtime/executor/qwen_executor.go | 4 +- internal/thinking/apply.go | 14 +++-- test/thinking_conversion_test.go | 2 +- 13 files changed, 85 insertions(+), 43 deletions(-) diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 970c2dc9..5de0ba4a 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -78,6 +78,8 @@ type ThinkingSupport struct { type ModelRegistration struct { // Info contains the model metadata Info *ModelInfo + // InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities. + InfoByProvider map[string]*ModelInfo // Count is the number of active clients that can provide this model Count int // LastUpdated tracks when this registration was last modified @@ -132,16 +134,19 @@ func GetGlobalRegistry() *ModelRegistry { return globalRegistry } -// LookupModelInfo searches the dynamic registry first, then falls back to static model definitions. -// -// This helper exists because some code paths only have a model ID and still need Thinking and -// max completion token metadata even when the dynamic registry hasn't been populated. -func LookupModelInfo(modelID string) *ModelInfo { +// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions. +func LookupModelInfo(modelID string, provider ...string) *ModelInfo { modelID = strings.TrimSpace(modelID) if modelID == "" { return nil } - if info := GetGlobalRegistry().GetModelInfo(modelID); info != nil { + + p := "" + if len(provider) > 0 { + p = strings.ToLower(strings.TrimSpace(provider[0])) + } + + if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil { return info } return LookupStaticModelInfo(modelID) @@ -297,6 +302,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ if count, okProv := reg.Providers[oldProvider]; okProv { if count <= toRemove { delete(reg.Providers, oldProvider) + if reg.InfoByProvider != nil { + delete(reg.InfoByProvider, oldProvider) + } } else { reg.Providers[oldProvider] = count - toRemove } @@ -346,6 +354,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ model := newModels[id] if reg, ok := r.models[id]; ok { reg.Info = cloneModelInfo(model) + if provider != "" { + if reg.InfoByProvider == nil { + reg.InfoByProvider = make(map[string]*ModelInfo) + } + reg.InfoByProvider[provider] = cloneModelInfo(model) + } reg.LastUpdated = now if reg.QuotaExceededClients != nil { delete(reg.QuotaExceededClients, clientID) @@ -409,11 +423,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo if existing.SuspendedClients == nil { existing.SuspendedClients = make(map[string]string) } + if existing.InfoByProvider == nil { + existing.InfoByProvider = make(map[string]*ModelInfo) + } if provider != "" { if existing.Providers == nil { existing.Providers = make(map[string]int) } existing.Providers[provider]++ + existing.InfoByProvider[provider] = cloneModelInfo(model) } log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count) return @@ -421,6 +439,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo registration := &ModelRegistration{ Info: cloneModelInfo(model), + InfoByProvider: make(map[string]*ModelInfo), Count: 1, LastUpdated: now, QuotaExceededClients: make(map[string]*time.Time), @@ -428,6 +447,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo } if provider != "" { registration.Providers = map[string]int{provider: 1} + registration.InfoByProvider[provider] = cloneModelInfo(model) } r.models[modelID] = registration log.Debugf("Registered new model %s from provider %s", modelID, provider) @@ -453,6 +473,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri if count, ok := registration.Providers[provider]; ok { if count <= 1 { delete(registration.Providers, provider) + if registration.InfoByProvider != nil { + delete(registration.InfoByProvider, provider) + } } else { registration.Providers[provider] = count - 1 } @@ -534,6 +557,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) { if count, ok := registration.Providers[provider]; ok { if count <= 1 { delete(registration.Providers, provider) + if registration.InfoByProvider != nil { + delete(registration.InfoByProvider, provider) + } } else { registration.Providers[provider] = count - 1 } @@ -940,12 +966,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { return result } -// GetModelInfo returns the registered ModelInfo for the given model ID, if present. -// Returns nil if the model is unknown to the registry. -func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo { +// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available. +func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo { r.mutex.RLock() defer r.mutex.RUnlock() if reg, ok := r.models[modelID]; ok && reg != nil { + // Try provider specific definition first + if provider != "" && reg.InfoByProvider != nil { + if reg.Providers != nil { + if count, ok := reg.Providers[provider]; ok && count > 0 { + if info, ok := reg.InfoByProvider[provider]; ok && info != nil { + return info + } + } + } + } + // Fallback to global info (last registered) return reg.Info } return nil diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index a020c670..eba38b00 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -393,7 +393,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c } originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream) payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream) - payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String()) + payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, translatedPayload{}, err } diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index df26e376..55cc1626 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -137,7 +137,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -256,7 +256,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -622,7 +622,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -802,7 +802,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut // Prepare payload once (doesn't depend on baseURL) payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String()) + payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index b6d5418a..d5b3132a 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -105,7 +105,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -235,7 +235,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index cc0e32a1..a283df86 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -96,7 +96,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body = sdktranslator.TranslateRequest(from, to, baseModel, body, false) body = misc.StripCodexUserAgent(body) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -208,7 +208,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body = sdktranslator.TranslateRequest(from, to, baseModel, body, true) body = misc.StripCodexUserAgent(body) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -316,7 +316,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth body = sdktranslator.TranslateRequest(from, to, baseModel, body, false) body = misc.StripCodexUserAgent(body) - body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index b23406af..ba321ca5 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -123,7 +123,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String()) + basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -272,7 +272,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String()) + basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -479,7 +479,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. for range models { payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String()) + payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index e9f9dbca..2c7a860c 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -120,7 +120,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -222,7 +222,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -338,7 +338,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String()) + translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 1184c07e..302989c8 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -319,7 +319,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -432,7 +432,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -535,7 +535,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -658,7 +658,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -773,7 +773,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String()) + translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } @@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String()) + translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 3e6ca4e5..c62c0659 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -92,7 +92,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier()) if err != nil { return resp, err } @@ -190,7 +190,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow") + body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier()) if err != nil { return nil, err } diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index a2bef724..d910294a 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -92,7 +92,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated) - translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -187,7 +187,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated) - translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } @@ -297,7 +297,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau modelForCounting := baseModel - translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String()) + translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return cliproxyexecutor.Response{}, err } diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 260165d9..e013f594 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -86,7 +86,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return resp, err } @@ -172,7 +172,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index cf0e373b..58c26286 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -63,6 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool { // - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)") // - fromFormat: Source request format (e.g., openai, codex, gemini) // - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow) +// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai) // // Returns: // - Modified request body JSON with thinking configuration applied @@ -79,12 +80,16 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool { // Example: // // // With suffix - suffix config takes priority -// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini") +// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini") // // // Without suffix - uses body config -// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini") -func ApplyThinking(body []byte, model string, fromFormat string, toFormat string) ([]byte, error) { +// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini") +func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) { providerFormat := strings.ToLower(strings.TrimSpace(toFormat)) + providerKey = strings.ToLower(strings.TrimSpace(providerKey)) + if providerKey == "" { + providerKey = providerFormat + } fromFormat = strings.ToLower(strings.TrimSpace(fromFormat)) if fromFormat == "" { fromFormat = providerFormat @@ -102,7 +107,8 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string // 2. Parse suffix and get modelInfo suffixResult := ParseSuffix(model) baseModel := suffixResult.ModelName - modelInfo := registry.LookupModelInfo(baseModel) + // Use provider-specific lookup to handle capability differences across providers. + modelInfo := registry.LookupModelInfo(baseModel, providerKey) // 3. Model capability check // Unknown models are treated as user-defined so thinking config can still be applied. diff --git a/test/thinking_conversion_test.go b/test/thinking_conversion_test.go index 4a7df29a..3ad26ea6 100644 --- a/test/thinking_conversion_test.go +++ b/test/thinking_conversion_test.go @@ -2712,7 +2712,7 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) { body, _ = sjson.SetBytes(body, "max_tokens", 200000) } - body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo) + body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo, applyTo) if tc.expectErr { if err == nil {