From 97ef633c57947364914dccbf2470ca9f81bf58ba Mon Sep 17 00:00:00 2001 From: chujian <765781379@qq.com> Date: Sat, 7 Mar 2026 17:36:57 +0800 Subject: [PATCH] fix(registry): address review feedback --- internal/registry/model_registry.go | 33 +++++++++++----- .../registry/model_registry_safety_test.go | 38 +++++++++++++++++++ 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 8b03c59e..becd4c3a 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -173,7 +173,7 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo { if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil { return cloneModelInfo(info) } - return LookupStaticModelInfo(modelID) + return cloneModelInfo(LookupStaticModelInfo(modelID)) } // SetHook sets an optional hook for observing model registration changes. @@ -490,7 +490,6 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri registration.LastUpdated = now if registration.QuotaExceededClients != nil { delete(registration.QuotaExceededClients, clientID) - r.invalidateAvailableModelsCacheLocked() } if registration.SuspendedClients != nil { delete(registration.SuspendedClients, clientID) @@ -842,13 +841,34 @@ func cloneModelMaps(models []map[string]any) []map[string]any { } copyModel := make(map[string]any, len(model)) for key, value := range model { - copyModel[key] = value + copyModel[key] = cloneModelMapValue(value) } cloned = append(cloned, copyModel) } return cloned } +func cloneModelMapValue(value any) any { + switch typed := value.(type) { + case map[string]any: + copyMap := make(map[string]any, len(typed)) + for key, entry := range typed { + copyMap[key] = cloneModelMapValue(entry) + } + return copyMap + case []any: + copySlice := make([]any, len(typed)) + for i, entry := range typed { + copySlice[i] = cloneModelMapValue(entry) + } + return copySlice + case []string: + return append([]string(nil), typed...) + default: + return value + } +} + // GetAvailableModelsByProvider returns models available for the given provider identifier. // Parameters: // - provider: Provider identifier (e.g., "codex", "gemini", "antigravity") @@ -1298,10 +1318,3 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo { } return result } - - - - - - - diff --git a/internal/registry/model_registry_safety_test.go b/internal/registry/model_registry_safety_test.go index 0f3ffe51..5f4f65d2 100644 --- a/internal/registry/model_registry_safety_test.go +++ b/internal/registry/model_registry_safety_test.go @@ -109,3 +109,41 @@ func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) { t.Fatalf("expected model id m1, got %v", got) } } + +func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "openai", []*ModelInfo{{ + ID: "m1", + DisplayName: "Model One", + SupportedParameters: []string{"temperature", "top_p"}, + }}) + + first := r.GetAvailableModels("openai") + if len(first) != 1 { + t.Fatalf("expected one model, got %d", len(first)) + } + params, ok := first[0]["supported_parameters"].([]string) + if !ok || len(params) != 2 { + t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"]) + } + params[0] = "mutated" + + second := r.GetAvailableModels("openai") + params, ok = second[0]["supported_parameters"].([]string) + if !ok || len(params) != 2 || params[0] != "temperature" { + t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"]) + } +} + +func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) { + first := LookupModelInfo("glm-4.6") + if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 { + t.Fatalf("expected static model with thinking levels, got %+v", first) + } + first.Thinking.Levels[0] = "mutated" + + second := LookupModelInfo("glm-4.6") + if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" { + t.Fatalf("expected static lookup clone, got %+v", second) + } +}