From ddcf1f279d6150ef9fe19675b8cd6e2fbec4ee42 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 7 Mar 2026 13:11:28 +0800 Subject: [PATCH] Fixed: #1901 test(websocket): add tests for incremental input and prewarm handling logic - Added test cases for incremental input support based on upstream capabilities. - Introduced validation for prewarm handling of `response.create` messages locally. - Enhanced test coverage for websocket executor behavior, including payload forwarding checks. - Updated websocket implementation with prewarm and incremental input logic for better testability. --- .../openai/openai_responses_websocket.go | 280 ++++++++++++++++-- .../openai/openai_responses_websocket_test.go | 166 +++++++++++ 2 files changed, 420 insertions(+), 26 deletions(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 5e2beb94..6a444b45 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -14,7 +14,11 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -100,11 +104,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { // ) appendWebsocketEvent(&wsBodyLog, "request", payload) - allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil) + allowIncrementalInputWithPreviousResponseID := false if pinnedAuthID != "" && h != nil && h.AuthManager != nil { if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) } + } else { + requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if requestModelName == "" { + requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + } + allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } var requestJSON []byte @@ -139,6 +149,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } + if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { + if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { + requestJSON = updated + } + if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil { + updatedLastRequest = updated + } + lastRequest = updatedLastRequest + lastResponseOutput = []byte("[]") + if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil { + wsTerminateErr = errWrite + appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error())) + return + } + continue + } lastRequest = updatedLastRequest modelName := gjson.GetBytes(requestJSON, "model").String() @@ -339,6 +365,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met return false } +func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool { + if h == nil || h.AuthManager == nil { + return false + } + + resolvedModelName := modelName + initialSuffix := thinking.ParseSuffix(modelName) + if initialSuffix.ModelName == "auto" { + resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName) + if initialSuffix.HasSuffix { + resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix) + } else { + resolvedModelName = resolvedBase + } + } else { + resolvedModelName = util.ResolveAutoModel(modelName) + } + + parsed := thinking.ParseSuffix(resolvedModelName) + baseModel := strings.TrimSpace(parsed.ModelName) + providers := util.GetProviderName(baseModel) + if len(providers) == 0 && baseModel != resolvedModelName { + providers = util.GetProviderName(resolvedModelName) + } + if len(providers) == 0 { + return false + } + + providerSet := make(map[string]struct{}, len(providers)) + for i := 0; i < len(providers); i++ { + providerKey := strings.TrimSpace(strings.ToLower(providers[i])) + if providerKey == "" { + continue + } + providerSet[providerKey] = struct{}{} + } + if len(providerSet) == 0 { + return false + } + + modelKey := baseModel + if modelKey == "" { + modelKey = strings.TrimSpace(resolvedModelName) + } + registryRef := registry.GetGlobalRegistry() + now := time.Now() + auths := h.AuthManager.List() + for i := 0; i < len(auths); i++ { + auth := auths[i] + if auth == nil { + continue + } + providerKey := strings.TrimSpace(strings.ToLower(auth.Provider)) + if _, ok := providerSet[providerKey]; !ok { + continue + } + if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) { + continue + } + if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) { + continue + } + if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) { + return true + } + } + return false +} + +func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool { + if auth == nil { + return false + } + if auth.Disabled || auth.Status == coreauth.StatusDisabled { + return false + } + if modelName != "" && len(auth.ModelStates) > 0 { + state, ok := auth.ModelStates[modelName] + if (!ok || state == nil) && modelName != "" { + baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName) + if baseModel != "" && baseModel != modelName { + state, ok = auth.ModelStates[baseModel] + } + } + if ok && state != nil { + if state.Status == coreauth.StatusDisabled { + return false + } + if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) { + return false + } + return true + } + } + if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) { + return false + } + return true +} + +func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool { + if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 { + return false + } + if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate { + return false + } + generateResult := gjson.GetBytes(rawJSON, "generate") + return generateResult.Exists() && !generateResult.Bool() +} + +func writeResponsesWebsocketSyntheticPrewarm( + c *gin.Context, + conn *websocket.Conn, + requestJSON []byte, + wsBodyLog *strings.Builder, + sessionID string, +) error { + payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON) + if errPayloads != nil { + return errPayloads + } + for i := 0; i < len(payloads); i++ { + markAPIResponseTimestamp(c) + appendWebsocketEvent(wsBodyLog, "response", payloads[i]) + // log.Infof( + // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", + // sessionID, + // websocket.TextMessage, + // websocketPayloadEventType(payloads[i]), + // websocketPayloadPreview(payloads[i]), + // ) + if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { + log.Warnf( + "responses websocket: downstream_out write failed id=%s event=%s error=%v", + sessionID, + websocketPayloadEventType(payloads[i]), + errWrite, + ) + return errWrite + } + } + return nil +} + +func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) { + responseID := "resp_prewarm_" + uuid.NewString() + createdAt := time.Now().Unix() + modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()) + + createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`) + var errSet error + createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID) + if errSet != nil { + return nil, errSet + } + createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt) + if errSet != nil { + return nil, errSet + } + if modelName != "" { + createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName) + if errSet != nil { + return nil, errSet + } + } + + completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`) + completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID) + if errSet != nil { + return nil, errSet + } + completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt) + if errSet != nil { + return nil, errSet + } + if modelName != "" { + completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName) + if errSet != nil { + return nil, errSet + } + } + + return [][]byte{createdPayload, completedPayload}, nil +} + func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) { existingRaw = strings.TrimSpace(existingRaw) appendRaw = strings.TrimSpace(appendRaw) @@ -550,47 +762,63 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error } body := handlers.BuildErrorResponseBody(status, errText) - payload := map[string]any{ - "type": wsEventTypeError, - "status": status, + payload := []byte(`{}`) + var errSet error + payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError) + if errSet != nil { + return nil, errSet + } + payload, errSet = sjson.SetBytes(payload, "status", status) + if errSet != nil { + return nil, errSet } if errMsg != nil && errMsg.Addon != nil { - headers := map[string]any{} + headers := []byte(`{}`) + hasHeaders := false for key, values := range errMsg.Addon { if len(values) == 0 { continue } - headers[key] = values[0] + headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`) + headers, errSet = sjson.SetBytes(headers, headerPath, values[0]) + if errSet != nil { + return nil, errSet + } + hasHeaders = true } - if len(headers) > 0 { - payload["headers"] = headers - } - } - - if len(body) > 0 && json.Valid(body) { - var decoded map[string]any - if errDecode := json.Unmarshal(body, &decoded); errDecode == nil { - if inner, ok := decoded["error"]; ok { - payload["error"] = inner - } else { - payload["error"] = decoded + if hasHeaders { + payload, errSet = sjson.SetRawBytes(payload, "headers", headers) + if errSet != nil { + return nil, errSet } } } - if _, ok := payload["error"]; !ok { - payload["error"] = map[string]any{ - "type": "server_error", - "message": errText, + if len(body) > 0 && json.Valid(body) { + errorNode := gjson.GetBytes(body, "error") + if errorNode.Exists() { + payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw)) + } else { + payload, errSet = sjson.SetRawBytes(payload, "error", body) + } + if errSet != nil { + return nil, errSet } } - data, err := json.Marshal(payload) - if err != nil { - return nil, err + if !gjson.GetBytes(payload, "error").Exists() { + payload, errSet = sjson.SetBytes(payload, "error.type", "server_error") + if errSet != nil { + return nil, errSet + } + payload, errSet = sjson.SetBytes(payload, "error.message", errText) + if errSet != nil { + return nil, errSet + } } - return data, conn.WriteMessage(websocket.TextMessage, data) + + return payload, conn.WriteMessage(websocket.TextMessage, payload) } func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index a04bb18c..d30c648d 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -2,7 +2,9 @@ package openai import ( "bytes" + "context" "errors" + "fmt" "net/http" "net/http/httptest" "strings" @@ -11,9 +13,46 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/tidwall/gjson" ) +type websocketCaptureExecutor struct { + streamCalls int + payloads [][]byte +} + +func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } + +func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.streamCalls++ + e.payloads = append(e.payloads, bytes.Clone(req.Payload)) + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) { raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`) @@ -326,3 +365,130 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { t.Fatalf("server error: %v", errServer) } } + +func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { + manager := coreauth.NewManager(nil, nil, nil) + auth := &coreauth.Auth{ + ID: "auth-ws", + Provider: "test-provider", + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") { + t.Fatalf("expected websocket-capable upstream for test-model") + } +} + +func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &websocketCaptureExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + errClose := conn.Close() + if errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`)) + if errWrite != nil { + t.Fatalf("write prewarm websocket message: %v", errWrite) + } + + _, createdPayload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read prewarm created message: %v", errReadMessage) + } + if gjson.GetBytes(createdPayload, "type").String() != "response.created" { + t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String()) + } + prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String() + if prewarmResponseID == "" { + t.Fatalf("prewarm response id is empty") + } + if executor.streamCalls != 0 { + t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls) + } + + _, completedPayload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read prewarm completed message: %v", errReadMessage) + } + if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted { + t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted) + } + if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID { + t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID) + } + if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 { + t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int()) + } + + secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID) + errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest)) + if errWrite != nil { + t.Fatalf("write follow-up websocket message: %v", errWrite) + } + + _, upstreamPayload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read upstream completed message: %v", errReadMessage) + } + if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted { + t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted) + } + if executor.streamCalls != 1 { + t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls) + } + if len(executor.payloads) != 1 { + t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads)) + } + forwarded := executor.payloads[0] + if gjson.GetBytes(forwarded, "previous_response_id").Exists() { + t.Fatalf("previous_response_id leaked upstream: %s", forwarded) + } + if gjson.GetBytes(forwarded, "generate").Exists() { + t.Fatalf("generate leaked upstream: %s", forwarded) + } + if gjson.GetBytes(forwarded, "model").String() != "test-model" { + t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String()) + } + input := gjson.GetBytes(forwarded, "input").Array() + if len(input) != 1 || input[0].Get("id").String() != "msg-1" { + t.Fatalf("unexpected forwarded input: %s", forwarded) + } +}