mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-05 04:01:22 +00:00
fix(websocket): pin only websocket-capable auth IDs and add corresponding test
This commit is contained in:
@@ -177,7 +177,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||
} else {
|
||||
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" || h == nil || h.AuthManager == nil {
|
||||
return
|
||||
}
|
||||
selectedAuth, ok := h.AuthManager.GetByID(authID)
|
||||
if !ok || selectedAuth == nil {
|
||||
return
|
||||
}
|
||||
if websocketUpstreamSupportsIncrementalInput(selectedAuth.Attributes, selectedAuth.Metadata) {
|
||||
pinnedAuthID = authID
|
||||
}
|
||||
})
|
||||
}
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -26,6 +27,78 @@ type websocketCaptureExecutor struct {
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
type orderedWebsocketSelector struct {
|
||||
mu sync.Mutex
|
||||
order []string
|
||||
cursor int
|
||||
}
|
||||
|
||||
func (s *orderedWebsocketSelector) Pick(_ context.Context, _ string, _ string, _ coreexecutor.Options, auths []*coreauth.Auth) (*coreauth.Auth, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(auths) == 0 {
|
||||
return nil, errors.New("no auth available")
|
||||
}
|
||||
for len(s.order) > 0 && s.cursor < len(s.order) {
|
||||
authID := strings.TrimSpace(s.order[s.cursor])
|
||||
s.cursor++
|
||||
for _, auth := range auths {
|
||||
if auth != nil && auth.ID == authID {
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, auth := range auths {
|
||||
if auth != nil {
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("no auth available")
|
||||
}
|
||||
|
||||
type websocketAuthCaptureExecutor struct {
|
||||
mu sync.Mutex
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
if auth != nil {
|
||||
e.authIDs = append(e.authIDs, auth.ID)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return append([]string(nil), e.authIDs...)
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -519,3 +592,73 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
selector := &orderedWebsocketSelector{order: []string{"auth-sse", "auth-ws"}}
|
||||
executor := &websocketAuthCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, selector, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
authSSE := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), authSSE); err != nil {
|
||||
t.Fatalf("Register SSE auth: %v", err)
|
||||
}
|
||||
authWS := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: executor.Identifier(),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), authWS); err != nil {
|
||||
t.Fatalf("Register websocket auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(authSSE.ID, authSSE.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(authSSE.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
requests := []string{
|
||||
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
|
||||
`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`,
|
||||
}
|
||||
for i := range requests {
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
|
||||
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
|
||||
}
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
|
||||
}
|
||||
}
|
||||
|
||||
if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-sse" || got[1] != "auth-ws" {
|
||||
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user