From dc7187ca5b611035f2ce67e2ed71cf3c5e713d3a Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 16 Mar 2026 09:57:38 +0800 Subject: [PATCH] fix(websocket): pin only websocket-capable auth IDs and add corresponding test --- .../openai/openai_responses_websocket.go | 12 +- .../openai/openai_responses_websocket_test.go | 143 ++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index d417d6b2..5c68f40e 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -177,7 +177,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID) } else { cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) { - pinnedAuthID = strings.TrimSpace(authID) + authID = strings.TrimSpace(authID) + if authID == "" || h == nil || h.AuthManager == nil { + return + } + selectedAuth, ok := h.AuthManager.GetByID(authID) + if !ok || selectedAuth == nil { + return + } + if websocketUpstreamSupportsIncrementalInput(selectedAuth.Attributes, selectedAuth.Metadata) { + pinnedAuthID = authID + } }) } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 981c6630..b3a32c5c 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/gin-gonic/gin" @@ -26,6 +27,78 @@ type websocketCaptureExecutor struct { payloads [][]byte } +type orderedWebsocketSelector struct { + mu sync.Mutex + order []string + cursor int +} + +func (s *orderedWebsocketSelector) Pick(_ context.Context, _ string, _ string, _ coreexecutor.Options, auths []*coreauth.Auth) (*coreauth.Auth, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(auths) == 0 { + return nil, errors.New("no auth available") + } + for len(s.order) > 0 && s.cursor < len(s.order) { + authID := strings.TrimSpace(s.order[s.cursor]) + s.cursor++ + for _, auth := range auths { + if auth != nil && auth.ID == authID { + return auth, nil + } + } + } + for _, auth := range auths { + if auth != nil { + return auth, nil + } + } + return nil, errors.New("no auth available") +} + +type websocketAuthCaptureExecutor struct { + mu sync.Mutex + authIDs []string +} + +func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" } + +func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketAuthCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.mu.Lock() + if auth != nil { + e.authIDs = append(e.authIDs, auth.ID) + } + e.mu.Unlock() + + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketAuthCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketAuthCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketAuthCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketAuthCaptureExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + return append([]string(nil), e.authIDs...) +} + func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -519,3 +592,73 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) { t.Fatalf("unexpected forwarded input: %s", forwarded) } } + +func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + + selector := &orderedWebsocketSelector{order: []string{"auth-sse", "auth-ws"}} + executor := &websocketAuthCaptureExecutor{} + manager := coreauth.NewManager(nil, selector, nil) + manager.RegisterExecutor(executor) + + authSSE := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), authSSE); err != nil { + t.Fatalf("Register SSE auth: %v", err) + } + authWS := &coreauth.Auth{ + ID: "auth-ws", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), authWS); err != nil { + t.Fatalf("Register websocket auth: %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(authSSE.ID, authSSE.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(authSSE.ID) + registry.GetGlobalRegistry().UnregisterClient(authWS.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + requests := []string{ + `{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`, + `{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`, + } + for i := range requests { + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil { + t.Fatalf("write websocket message %d: %v", i+1, errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message %d: %v", i+1, errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted) + } + } + + if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-sse" || got[1] != "auth-ws" { + t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got) + } +}