From 3839d93ba040ace1420845698b5efb5b30519af2 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 11:30:39 +0800 Subject: [PATCH 1/6] feat: add websocket routing and executor unregister API - Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes. --- go.mod | 1 + go.sum | 4 +- internal/api/server.go | 33 +++ .../runtime/executor/aistudio_executor.go | 264 ++++++++++++++++++ internal/wsrelay/http.go | 187 +++++++++++++ internal/wsrelay/manager.go | 200 +++++++++++++ internal/wsrelay/message.go | 27 ++ internal/wsrelay/session.go | 188 +++++++++++++ sdk/cliproxy/auth/manager.go | 11 + sdk/cliproxy/auth/types.go | 12 +- sdk/cliproxy/service.go | 111 ++++++++ 11 files changed, 1035 insertions(+), 3 deletions(-) create mode 100644 internal/runtime/executor/aistudio_executor.go create mode 100644 internal/wsrelay/http.go create mode 100644 internal/wsrelay/manager.go create mode 100644 internal/wsrelay/message.go create mode 100644 internal/wsrelay/session.go diff --git a/go.mod b/go.mod index df03ac4e..010c8a6e 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gin-gonic/gin v1.10.1 github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/jackc/pgx/v5 v5.7.6 github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.4 diff --git a/go.sum b/go.sum index cba1c68c..b5cfca4a 100644 --- a/go.sum +++ 
b/go.sum @@ -66,6 +66,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -80,8 +82,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ= github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= -github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= -github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/api/server.go b/internal/api/server.go index aae2b0e0..a41861c2 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "strings" + "sync" "sync/atomic" "time" @@ -138,6 +139,10 @@ type Server struct { // currentPath is the absolute path to the current working directory. 
currentPath string + // wsRoutes tracks registered websocket upgrade paths. + wsRouteMu sync.Mutex + wsRoutes map[string]struct{} + // management handler mgmt *managementHandlers.Handler @@ -228,6 +233,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk configFilePath: configFilePath, currentPath: wd, envManagementSecret: envManagementSecret, + wsRoutes: make(map[string]struct{}), } // Save initial YAML snapshot s.oldConfigYaml, _ = yaml.Marshal(cfg) @@ -371,6 +377,33 @@ func (s *Server) setupRoutes() { // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } +// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine. +// The handler is served as-is without additional middleware beyond the standard stack already configured. +func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) { + if s == nil || s.engine == nil || handler == nil { + return + } + trimmed := strings.TrimSpace(path) + if trimmed == "" { + trimmed = "/v1/ws" + } + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + s.wsRouteMu.Lock() + if _, exists := s.wsRoutes[trimmed]; exists { + s.wsRouteMu.Unlock() + return + } + s.wsRoutes[trimmed] = struct{}{} + s.wsRouteMu.Unlock() + + s.engine.GET(trimmed, func(c *gin.Context) { + handler.ServeHTTP(c.Writer, c.Request) + c.Abort() + }) +} + func (s *Server) registerManagementRoutes() { if s == nil || s.engine == nil || s.mgmt == nil { return diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go new file mode 100644 index 00000000..3eb9af24 --- /dev/null +++ b/internal/runtime/executor/aistudio_executor.go @@ -0,0 +1,264 @@ +package executor + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + 
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/sjson" +) + +// AistudioExecutor routes AI Studio requests through a websocket-backed transport. +type AistudioExecutor struct { + provider string + relay *wsrelay.Manager + cfg *config.Config +} + +// NewAistudioExecutor constructs a websocket executor for the provider name. +func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor { + return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg} +} + +// Identifier returns the provider key served by this executor. +func (e *AistudioExecutor) Identifier() string { return e.provider } + +// PrepareRequest is a no-op because websocket transport already injects headers. 
+func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { + return nil +} + +func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + translatedReq, body, err := e.translateRequest(req, opts, false) + if err != nil { + return cliproxyexecutor.Response{}, err + } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) + wsReq := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: body.payload, + } + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: endpoint, + Method: http.MethodPost, + Headers: wsReq.Headers.Clone(), + Body: bytes.Clone(body.payload), + Provider: e.provider, + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return cliproxyexecutor.Response{}, err + } + recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) + if len(resp.Body) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body)) + } + if resp.Status < 200 || resp.Status >= 300 { + return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} + } + var param any + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan 
cliproxyexecutor.StreamChunk, error) { + translatedReq, body, err := e.translateRequest(req, opts, true) + if err != nil { + return nil, err + } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) + wsReq := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: body.payload, + } + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: endpoint, + Method: http.MethodPost, + Headers: wsReq.Headers.Clone(), + Body: bytes.Clone(body.payload), + Provider: e.provider, + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + stream, err := e.relay.Stream(ctx, e.provider, wsReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + var param any + metadataLogged := false + for event := range stream { + if event.Err != nil { + recordAPIResponseError(ctx, e.cfg, event.Err) + out <- cliproxyexecutor.StreamChunk{Err: event.Err} + return + } + switch event.Type { + case wsrelay.MessageTypeStreamStart: + if !metadataLogged && event.Status > 0 { + recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) + metadataLogged = true + } + case wsrelay.MessageTypeStreamChunk: + if len(event.Payload) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + } + lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + case wsrelay.MessageTypeStreamEnd: + return + case wsrelay.MessageTypeHTTPResp: + if !metadataLogged && event.Status > 0 { + 
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) + metadataLogged = true + } + if len(event.Payload) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + } + lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + return + case wsrelay.MessageTypeError: + recordAPIResponseError(ctx, e.cfg, event.Err) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} + return + } + } + }() + return out, nil +} + +func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + translatedReq, body, err := e.translateRequest(req, opts, false) + if err != nil { + return cliproxyexecutor.Response{}, err + } + endpoint := e.buildEndpoint(req.Model, "countTokens", "") + wsReq := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: body.payload, + } + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: endpoint, + Method: http.MethodPost, + Headers: wsReq.Headers.Clone(), + Body: bytes.Clone(body.payload), + Provider: e.provider, + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return cliproxyexecutor.Response{}, err + } + recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) + if len(resp.Body) > 0 { + appendAPIResponseChunk(ctx, e.cfg, 
bytes.Clone(resp.Body)) + } + if resp.Status < 200 || resp.Status >= 300 { + return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} + } + var param any + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + _ = ctx + return auth, nil +} + +type translatedPayload struct { + payload []byte + action string + toFormat sdktranslator.Format +} + +func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok { + payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride) + } + payload = disableGeminiThinkingConfig(payload, req.Model) + payload = fixGeminiImageAspectRatio(req.Model, payload) + metadataAction := "generateContent" + if req.Metadata != nil { + if action, _ := req.Metadata["action"].(string); action == "countTokens" { + metadataAction = action + } + } + action := metadataAction + if stream && action != "countTokens" { + action = "streamGenerateContent" + } + payload, _ = sjson.DeleteBytes(payload, "session_id") + return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil +} + +func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string { + base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action) + if action == "streamGenerateContent" { + if alt == "" { + return base + "?alt=sse" + } + return base 
+ "?$alt=" + url.QueryEscape(alt) + } + if alt != "" && action != "countTokens" { + return base + "?$alt=" + url.QueryEscape(alt) + } + return base +} diff --git a/internal/wsrelay/http.go b/internal/wsrelay/http.go new file mode 100644 index 00000000..96f80ec3 --- /dev/null +++ b/internal/wsrelay/http.go @@ -0,0 +1,187 @@ +package wsrelay + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" +) + +// HTTPRequest represents a proxied HTTP request delivered to websocket clients. +type HTTPRequest struct { + Method string + URL string + Headers http.Header + Body []byte +} + +// HTTPResponse captures the response relayed back from websocket clients. +type HTTPResponse struct { + Status int + Headers http.Header + Body []byte +} + +// StreamEvent represents a streaming response event from clients. +type StreamEvent struct { + Type string + Payload []byte + Status int + Headers http.Header + Err error +} + +// RoundTrip executes a non-streaming HTTP request using the websocket provider. +func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) { + if req == nil { + return nil, fmt.Errorf("wsrelay: request is nil") + } + msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} + respCh, err := m.Send(ctx, provider, msg) + if err != nil { + return nil, err + } + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-respCh: + if !ok { + return nil, errors.New("wsrelay: connection closed during response") + } + switch msg.Type { + case MessageTypeHTTPResp: + return decodeResponse(msg.Payload), nil + case MessageTypeError: + return nil, decodeError(msg.Payload) + case MessageTypeStreamStart, MessageTypeStreamChunk: + // Ignore streaming noise in non-stream requests. + default: + } + } + } +} + +// Stream executes a streaming HTTP request and returns channel with stream events. 
+func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) { + if req == nil { + return nil, fmt.Errorf("wsrelay: request is nil") + } + msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} + respCh, err := m.Send(ctx, provider, msg) + if err != nil { + return nil, err + } + out := make(chan StreamEvent) + go func() { + defer close(out) + for { + select { + case <-ctx.Done(): + out <- StreamEvent{Err: ctx.Err()} + return + case msg, ok := <-respCh: + if !ok { + out <- StreamEvent{Err: errors.New("wsrelay: stream closed")} + return + } + switch msg.Type { + case MessageTypeStreamStart: + resp := decodeResponse(msg.Payload) + out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers} + case MessageTypeStreamChunk: + chunk := decodeChunk(msg.Payload) + out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk} + case MessageTypeStreamEnd: + out <- StreamEvent{Type: MessageTypeStreamEnd} + return + case MessageTypeError: + out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)} + return + case MessageTypeHTTPResp: + resp := decodeResponse(msg.Payload) + out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body} + return + default: + } + } + } + }() + return out, nil +} + +func encodeRequest(req *HTTPRequest) map[string]any { + headers := make(map[string]any, len(req.Headers)) + for key, values := range req.Headers { + copyValues := make([]string, len(values)) + copy(copyValues, values) + headers[key] = copyValues + } + return map[string]any{ + "method": req.Method, + "url": req.URL, + "headers": headers, + "body": string(req.Body), + "sent_at": time.Now().UTC().Format(time.RFC3339Nano), + } +} + +func decodeResponse(payload map[string]any) *HTTPResponse { + if payload == nil { + return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)} + } + resp := 
// decodeChunk extracts the chunk bytes from a stream_chunk payload.
// It returns nil when the payload is nil or its "data" entry is not a string.
func decodeChunk(payload map[string]any) []byte {
	// Indexing a nil map is safe and simply fails the type assertion.
	text, isString := payload["data"].(string)
	if !isString {
		return nil
	}
	return []byte(text)
}

// decodeError turns an error payload into a Go error.
//
// A nil payload yields a generic "unknown error". Otherwise the message comes
// from the "error" entry (with a generic fallback when it is missing or empty)
// and is suffixed with the numeric "status" entry, or 0 when there is none.
func decodeError(payload map[string]any) error {
	if payload == nil {
		return errors.New("wsrelay: unknown error")
	}

	reason, _ := payload["error"].(string)
	if reason == "" {
		reason = "wsrelay: upstream error"
	}

	// Only float64 is recognised as a status value.
	var code int
	if raw, isNumber := payload["status"].(float64); isNumber {
		code = int(raw)
	}
	return fmt.Errorf("%s (status=%d)", reason, code)
}
+type Options struct { + Path string + ProviderFactory func(*http.Request) (string, error) + OnConnected func(string) + OnDisconnected func(string, error) + LogDebugf func(string, ...any) + LogInfof func(string, ...any) + LogWarnf func(string, ...any) +} + +// NewManager builds a websocket relay manager with the supplied options. +func NewManager(opts Options) *Manager { + path := strings.TrimSpace(opts.Path) + if path == "" { + path = "/v1/ws" + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + mgr := &Manager{ + path: path, + sessions: make(map[string]*session), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + providerFactory: opts.ProviderFactory, + onConnected: opts.OnConnected, + onDisconnected: opts.OnDisconnected, + logDebugf: opts.LogDebugf, + logInfof: opts.LogInfof, + logWarnf: opts.LogWarnf, + } + if mgr.logDebugf == nil { + mgr.logDebugf = func(string, ...any) {} + } + if mgr.logInfof == nil { + mgr.logInfof = func(string, ...any) {} + } + if mgr.logWarnf == nil { + mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) } + } + return mgr +} + +// Path returns the HTTP path the manager expects for websocket upgrades. +func (m *Manager) Path() string { + if m == nil { + return "/v1/ws" + } + return m.path +} + +// Handler exposes an http.Handler that upgrades connections to websocket sessions. +func (m *Manager) Handler() http.Handler { + return http.HandlerFunc(m.handleWebsocket) +} + +// Stop gracefully closes all active websocket sessions. 
+func (m *Manager) Stop(_ context.Context) error { + m.sessMutex.Lock() + sessions := make([]*session, 0, len(m.sessions)) + for _, sess := range m.sessions { + sessions = append(sessions, sess) + } + m.sessions = make(map[string]*session) + m.sessMutex.Unlock() + + for _, sess := range sessions { + if sess != nil { + sess.cleanup(errors.New("wsrelay: manager stopped")) + } + } + return nil +} + +// handleWebsocket upgrades the connection and wires the session into the pool. +func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) { + expectedPath := m.Path() + if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath { + http.NotFound(w, r) + return + } + if !strings.EqualFold(r.Method, http.MethodGet) { + w.Header().Set("Allow", http.MethodGet) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + conn, err := m.upgrader.Upgrade(w, r, nil) + if err != nil { + m.logWarnf("wsrelay: upgrade failed: %v", err) + return + } + s := newSession(conn, m, randomProviderName()) + if m.providerFactory != nil { + name, err := m.providerFactory(r) + if err != nil { + s.cleanup(err) + return + } + if strings.TrimSpace(name) != "" { + s.provider = strings.ToLower(name) + } + } + if s.provider == "" { + s.provider = strings.ToLower(s.id) + } + m.sessMutex.Lock() + if existing, ok := m.sessions[s.provider]; ok { + existing.cleanup(errors.New("replaced by new connection")) + } + m.sessions[s.provider] = s + m.sessMutex.Unlock() + if m.onConnected != nil { + m.onConnected(s.provider) + } + + go s.run(context.Background()) +} + +// Send forwards the message to the specific provider connection and returns a channel +// yielding response messages. 
// randomProviderName returns "aistudio-" followed by 16 characters drawn
// uniformly from [a-z0-9] using crypto/rand.
//
// If the random source fails, it falls back to "aistudio-" plus the current
// Unix time in nanoseconds, formatted as hexadecimal.
func randomProviderName() string {
	const (
		prefix    = "aistudio-"
		alphabet  = "abcdefghijklmnopqrstuvwxyz0123456789"
		suffixLen = 16
		// cutoff is the largest multiple of len(alphabet) that fits in a byte.
		// Bytes at or above it are rejected: reducing them modulo len(alphabet)
		// would make the first 256%len(alphabet) characters slightly more likely.
		cutoff = 256 - 256%len(alphabet)
	)

	suffix := make([]byte, 0, suffixLen)
	buf := make([]byte, suffixLen)
	for len(suffix) < suffixLen {
		if _, err := rand.Read(buf); err != nil {
			return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
		}
		for _, b := range buf {
			if int(b) >= cutoff {
				continue
			}
			suffix = append(suffix, alphabet[int(b)%len(alphabet)])
			if len(suffix) == suffixLen {
				break
			}
		}
	}
	return prefix + string(suffix)
}
+ MessageTypeStreamChunk = "stream_chunk" + // MessageTypeStreamEnd marks the completion of a streaming response. + MessageTypeStreamEnd = "stream_end" + // MessageTypeError carries an error response. + MessageTypeError = "error" + // MessageTypePing represents ping messages from clients. + MessageTypePing = "ping" + // MessageTypePong represents pong responses back to clients. + MessageTypePong = "pong" +) diff --git a/internal/wsrelay/session.go b/internal/wsrelay/session.go new file mode 100644 index 00000000..a728cbc3 --- /dev/null +++ b/internal/wsrelay/session.go @@ -0,0 +1,188 @@ +package wsrelay + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + readTimeout = 60 * time.Second + writeTimeout = 10 * time.Second + maxInboundMessageLen = 64 << 20 // 64 MiB + heartbeatInterval = 30 * time.Second +) + +var errClosed = errors.New("websocket session closed") + +type pendingRequest struct { + ch chan Message + closeOnce sync.Once +} + +func (pr *pendingRequest) close() { + if pr == nil { + return + } + pr.closeOnce.Do(func() { + close(pr.ch) + }) +} + +type session struct { + conn *websocket.Conn + manager *Manager + provider string + id string + closed chan struct{} + closeOnce sync.Once + writeMutex sync.Mutex + pending sync.Map // map[string]*pendingRequest +} + +func newSession(conn *websocket.Conn, mgr *Manager, id string) *session { + s := &session{ + conn: conn, + manager: mgr, + provider: "", + id: id, + closed: make(chan struct{}), + } + conn.SetReadLimit(maxInboundMessageLen) + conn.SetReadDeadline(time.Now().Add(readTimeout)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + s.startHeartbeat() + return s +} + +func (s *session) startHeartbeat() { + if s == nil || s.conn == nil { + return + } + ticker := time.NewTicker(heartbeatInterval) + go func() { + defer ticker.Stop() + for { + select { + case <-s.closed: + return + case 
<-ticker.C: + s.writeMutex.Lock() + err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout)) + s.writeMutex.Unlock() + if err != nil { + s.cleanup(err) + return + } + } + } + }() +} + +func (s *session) run(ctx context.Context) { + defer s.cleanup(errClosed) + for { + var msg Message + if err := s.conn.ReadJSON(&msg); err != nil { + s.cleanup(err) + return + } + s.dispatch(msg) + } +} + +func (s *session) dispatch(msg Message) { + if msg.Type == MessageTypePing { + _ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong}) + return + } + if value, ok := s.pending.Load(msg.ID); ok { + req := value.(*pendingRequest) + select { + case req.ch <- msg: + default: + } + if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd { + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + } + return + } + if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd { + s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider) + } +} + +func (s *session) send(ctx context.Context, msg Message) error { + select { + case <-s.closed: + return errClosed + default: + } + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + if err := s.conn.WriteJSON(msg); err != nil { + return fmt.Errorf("write json: %w", err) + } + return nil +} + +func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) { + if msg.ID == "" { + return nil, fmt.Errorf("wsrelay: message id is required") + } + if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded { + return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID) + } + value, _ := s.pending.Load(msg.ID) + 
req := value.(*pendingRequest) + if err := s.send(ctx, msg); err != nil { + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + req := actual.(*pendingRequest) + req.close() + } + return nil, err + } + go func() { + select { + case <-ctx.Done(): + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + case <-s.closed: + } + }() + return req.ch, nil +} + +func (s *session) cleanup(cause error) { + s.closeOnce.Do(func() { + close(s.closed) + s.pending.Range(func(key, value any) bool { + req := value.(*pendingRequest) + msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}} + select { + case req.ch <- msg: + default: + } + req.close() + return true + }) + s.pending = sync.Map{} + _ = s.conn.Close() + if s.manager != nil { + s.manager.handleSessionClosed(s, cause) + } + }) +} diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index c2e87d9d..2cf7c77e 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) { m.executors[executor.Identifier()] = executor } +// UnregisterExecutor removes the executor associated with the provider key. +func (m *Manager) UnregisterExecutor(provider string) { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return + } + m.mu.Lock() + delete(m.executors, provider) + m.mu.Unlock() +} + // Register inserts a new auth entry into the manager. 
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) { if auth == nil { diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 35594bd8..2755383f 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -156,7 +156,17 @@ func (a *Auth) AccountInfo() (string, string) { if v, ok := a.Metadata["email"].(string); ok { return "oauth", v } - } else if a.Attributes != nil { + } + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") { + if label := strings.TrimSpace(a.Label); label != "" { + return "oauth", label + } + if id := strings.TrimSpace(a.ID); id != "" { + return "oauth", id + } + return "oauth", "aistudio" + } + if a.Attributes != nil { if v := a.Attributes["api_key"]; v != "" { return "api_key", v } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index f0b44884..ccbdf903 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -18,6 +18,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" + "github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -82,6 +83,9 @@ type Service struct { // shutdownOnce ensures shutdown is called only once. shutdownOnce sync.Once + + // wsGateway manages websocket Gemini providers. + wsGateway *wsrelay.Manager } // RegisterUsagePlugin registers a usage plugin on the global usage manager. 
@@ -172,6 +176,66 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat } } +func (s *Service) ensureWebsocketGateway() { + if s == nil { + return + } + if s.wsGateway != nil { + return + } + opts := wsrelay.Options{ + Path: "/v1/ws", + OnConnected: s.wsOnConnected, + OnDisconnected: s.wsOnDisconnected, + LogDebugf: log.Debugf, + LogInfof: log.Infof, + LogWarnf: log.Warnf, + } + s.wsGateway = wsrelay.NewManager(opts) +} + +func (s *Service) wsOnConnected(provider string) { + if s == nil || provider == "" { + return + } + if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") { + return + } + if s.coreManager != nil { + if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil { + return + } + } + now := time.Now().UTC() + auth := &coreauth.Auth{ + ID: provider, + Provider: provider, + Label: provider, + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Attributes: map[string]string{"ws_provider": "gemini"}, + } + log.Infof("websocket provider connected: %s", provider) + s.applyCoreAuthAddOrUpdate(context.Background(), auth) +} + +func (s *Service) wsOnDisconnected(provider string, reason error) { + if s == nil || provider == "" { + return + } + if reason != nil { + log.Warnf("websocket provider disconnected: %s (%v)", provider, reason) + } else { + log.Infof("websocket provider disconnected: %s", provider) + } + ctx := context.Background() + s.applyCoreAuthRemoval(ctx, provider) + if s.coreManager != nil { + s.coreManager.UnregisterExecutor(provider) + } +} + func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) { if s == nil || auth == nil || auth.ID == "" { return @@ -247,6 +311,12 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg)) return } + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") { + if s.wsGateway != nil { + 
s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway)) + } + return + } switch strings.ToLower(a.Provider) { case "gemini": s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) @@ -342,6 +412,11 @@ func (s *Service) Run(ctx context.Context) error { s.authManager = newDefaultAuthManager() } + s.ensureWebsocketGateway() + if s.server != nil && s.wsGateway != nil { + s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler()) + } + if s.hooks.OnBeforeStart != nil { s.hooks.OnBeforeStart(s.cfg) } @@ -449,6 +524,14 @@ func (s *Service) Shutdown(ctx context.Context) error { shutdownErr = err } } + if s.wsGateway != nil { + if err := s.wsGateway.Stop(ctx); err != nil { + log.Errorf("failed to stop websocket gateway: %v", err) + if shutdownErr == nil { + shutdownErr = err + } + } + } if s.authQueueStop != nil { s.authQueueStop() s.authQueueStop = nil @@ -505,6 +588,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } provider := strings.ToLower(strings.TrimSpace(a.Provider)) compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a) + if a.Attributes != nil { + if strings.EqualFold(a.Attributes["ws_provider"], "gemini") { + models := mergeGeminiModels() + GlobalModelRegistry().RegisterClient(a.ID, provider, models) + return + } + } if compatDetected { provider = "openai-compatibility" } @@ -611,3 +701,24 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { GlobalModelRegistry().RegisterClient(a.ID, key, models) } } + +func mergeGeminiModels() []*ModelInfo { + models := make([]*ModelInfo, 0, 16) + seen := make(map[string]struct{}) + appendModels := func(items []*ModelInfo) { + for i := range items { + m := items[i] + if m == nil || m.ID == "" { + continue + } + if _, ok := seen[m.ID]; ok { + continue + } + seen[m.ID] = struct{}{} + models = append(models, m) + } + } + appendModels(registry.GetGeminiModels()) + appendModels(registry.GetGeminiCLIModels()) + return 
models +} From c32e0136050279c1d7da444e8d776147a0864dd5 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 14:45:42 +0800 Subject: [PATCH 2/6] feat(aistudio): track Gemini usage and improve stream errors --- .../runtime/executor/aistudio_executor.go | 59 ++++++++++++------- internal/wsrelay/manager.go | 7 ++- sdk/cliproxy/service.go | 8 ++- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 3eb9af24..4bcdab3a 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -14,6 +14,7 @@ import ( cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -37,10 +38,13 @@ func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) return nil } -func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { +func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + translatedReq, body, err := e.translateRequest(req, opts, false) if err != nil { - return cliproxyexecutor.Response{}, err + return resp, err } endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ @@ -68,24 +72,29 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, AuthValue: authValue, }) - resp, err := 
e.relay.RoundTrip(ctx, e.provider, wsReq) + wsResp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) - return cliproxyexecutor.Response{}, err + return resp, err } - recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) - if len(resp.Body) > 0 { - appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body)) + recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone()) + if len(wsResp.Body) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body)) } - if resp.Status < 200 || resp.Status >= 300 { - return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} + if wsResp.Status < 200 || wsResp.Status >= 300 { + return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)} } + reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) var param any - out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil } -func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + translatedReq, body, err := e.translateRequest(req, opts, true) if err != nil { return nil, err @@ 
-114,20 +123,22 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth AuthType: authType, AuthValue: authValue, }) - stream, err := e.relay.Stream(ctx, e.provider, wsReq) + wsStream, err := e.relay.Stream(ctx, e.provider, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return nil, err } out := make(chan cliproxyexecutor.StreamChunk) + stream = out go func() { defer close(out) var param any metadataLogged := false - for event := range stream { + for event := range wsStream { if event.Err != nil { recordAPIResponseError(ctx, e.cfg, event.Err) - out <- cliproxyexecutor.StreamChunk{Err: event.Err} + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} return } switch event.Type { @@ -139,6 +150,9 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth case wsrelay.MessageTypeStreamChunk: if len(event.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + if detail, ok := parseGeminiStreamUsage(event.Payload); ok { + reporter.publish(ctx, detail) + } } lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) for i := range lines { @@ -158,19 +172,21 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth for i := range lines { out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} } + reporter.publish(ctx, parseGeminiUsage(event.Payload)) return case wsrelay.MessageTypeError: recordAPIResponseError(ctx, e.cfg, event.Err) + reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} return } } }() - return out, nil + return stream, nil } func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - 
translatedReq, body, err := e.translateRequest(req, opts, false) + _, body, err := e.translateRequest(req, opts, false) if err != nil { return cliproxyexecutor.Response{}, err } @@ -210,9 +226,12 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A if resp.Status < 200 || resp.Status >= 300 { return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} } - var param any - out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int() + if totalTokens <= 0 { + return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response") + } + translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body)) + return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { diff --git a/internal/wsrelay/manager.go b/internal/wsrelay/manager.go index ab32f9f3..ae28234c 100644 --- a/internal/wsrelay/manager.go +++ b/internal/wsrelay/manager.go @@ -142,11 +142,16 @@ func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) { s.provider = strings.ToLower(s.id) } m.sessMutex.Lock() + var replaced *session if existing, ok := m.sessions[s.provider]; ok { - existing.cleanup(errors.New("replaced by new connection")) + replaced = existing } m.sessions[s.provider] = s m.sessMutex.Unlock() + + if replaced != nil { + replaced.cleanup(errors.New("replaced by new connection")) + } if m.onConnected != nil { m.onConnected(s.provider) } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index ccbdf903..b0f4605b 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -203,7 
+203,9 @@ func (s *Service) wsOnConnected(provider string) { } if s.coreManager != nil { if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil { - return + if !existing.Disabled && existing.Status == coreauth.StatusActive { + return + } } } now := time.Now().UTC() @@ -225,6 +227,10 @@ func (s *Service) wsOnDisconnected(provider string, reason error) { return } if reason != nil { + if strings.Contains(reason.Error(), "replaced by new connection") { + log.Infof("websocket provider replaced: %s", provider) + return + } log.Warnf("websocket provider disconnected: %s (%v)", provider, reason) } else { log.Infof("websocket provider disconnected: %s", provider) From 8aaed4cf09c6290bdce14b8d39717ff3a27bf786 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 15:57:27 +0800 Subject: [PATCH 3/6] feat(aistudio): support non-streaming responses --- internal/wsrelay/http.go | 50 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/internal/wsrelay/http.go b/internal/wsrelay/http.go index 96f80ec3..f34a61ca 100644 --- a/internal/wsrelay/http.go +++ b/internal/wsrelay/http.go @@ -1,6 +1,7 @@ package wsrelay import ( + "bytes" "context" "errors" "fmt" @@ -44,21 +45,66 @@ func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPReque if err != nil { return nil, err } + var ( + streamMode bool + streamResp *HTTPResponse + streamBody bytes.Buffer + ) for { select { case <-ctx.Done(): return nil, ctx.Err() case msg, ok := <-respCh: if !ok { + if streamMode { + if streamResp == nil { + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } else if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) 
+ return streamResp, nil + } return nil, errors.New("wsrelay: connection closed during response") } switch msg.Type { case MessageTypeHTTPResp: - return decodeResponse(msg.Payload), nil + resp := decodeResponse(msg.Payload) + if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 { + resp.Body = append(resp.Body[:0], streamBody.Bytes()...) + } + return resp, nil case MessageTypeError: return nil, decodeError(msg.Payload) case MessageTypeStreamStart, MessageTypeStreamChunk: - // Ignore streaming noise in non-stream requests. + if msg.Type == MessageTypeStreamStart { + streamMode = true + streamResp = decodeResponse(msg.Payload) + if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamBody.Reset() + continue + } + if !streamMode { + streamMode = true + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } + chunk := decodeChunk(msg.Payload) + if len(chunk) > 0 { + streamBody.Write(chunk) + } + case MessageTypeStreamEnd: + if !streamMode { + return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil + } + if streamResp == nil { + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } else if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) 
+ return streamResp, nil default: } } From ea6065f1b15e4567daa66cae94b91f0bb1ee944b Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 16:53:49 +0800 Subject: [PATCH 4/6] fix(aistudio): strip usage metadata from non-final stream chunks --- .../runtime/executor/aistudio_executor.go | 71 +++++++++++++++++-- 1 file changed, 66 insertions(+), 5 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 4bcdab3a..53de71c8 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -150,13 +150,15 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth case wsrelay.MessageTypeStreamChunk: if len(event.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) - if detail, ok := parseGeminiStreamUsage(event.Payload); ok { + filtered := filterAistudioUsageMetadata(event.Payload) + if detail, ok := parseGeminiStreamUsage(filtered); ok { reporter.publish(ctx, detail) } - } - lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) - for i := range lines { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + break } case wsrelay.MessageTypeStreamEnd: return @@ -281,3 +283,62 @@ func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string { } return base } + +// filterAistudioUsageMetadata removes usageMetadata from intermediate SSE events so that +// only the terminal chunk retains token statistics. 
+func filterAistudioUsageMetadata(payload []byte) []byte { + if len(payload) == 0 { + return payload + } + + lines := bytes.Split(payload, []byte("\n")) + modified := false + for idx, line := range lines { + trimmed := bytes.TrimSpace(line) + if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) { + continue + } + dataIdx := bytes.Index(line, []byte("data:")) + if dataIdx < 0 { + continue + } + rawJSON := bytes.TrimSpace(line[dataIdx+5:]) + cleaned, changed := stripUsageMetadataFromJSON(rawJSON) + if !changed { + continue + } + var rebuilt []byte + rebuilt = append(rebuilt, line[:dataIdx]...) + rebuilt = append(rebuilt, []byte("data:")...) + if len(cleaned) > 0 { + rebuilt = append(rebuilt, ' ') + rebuilt = append(rebuilt, cleaned...) + } + lines[idx] = rebuilt + modified = true + } + if !modified { + return payload + } + return bytes.Join(lines, []byte("\n")) +} + +// stripUsageMetadataFromJSON drops usageMetadata when no finishReason is present. +func stripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) { + jsonBytes := bytes.TrimSpace(rawJSON) + if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) { + return rawJSON, false + } + finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason") + if finishReason.Exists() && finishReason.String() != "" { + return rawJSON, false + } + if !gjson.GetBytes(jsonBytes, "usageMetadata").Exists() { + return rawJSON, false + } + cleaned, err := sjson.DeleteBytes(jsonBytes, "usageMetadata") + if err != nil { + return rawJSON, false + } + return cleaned, true +} From 359b8de44e2c265bde5da1a1080e7d1e71e3a4d2 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 21:40:20 +0800 Subject: [PATCH 5/6] feat(ws): add WebSocket auth --- config.example.yaml | 3 ++ internal/access/config_access/provider.go | 5 +- internal/api/middleware/request_logging.go | 13 +++--- internal/api/server.go | 32 +++++++++++-- internal/config/config.go | 3 ++ 
internal/logging/gin_logger.go | 3 +- internal/util/provider.go | 54 ++++++++++++++++++++++ internal/watcher/watcher.go | 3 ++ sdk/cliproxy/service.go | 17 ++++++- 9 files changed, 119 insertions(+), 14 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 428df70b..d5795719 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -43,6 +43,9 @@ quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded +# When true, enable authentication for the WebSocket API (/v1/ws). +ws-auth: false + # API keys for official Generative Language API #generative-language-api-key: # - "AIzaSy...01" diff --git a/internal/access/config_access/provider.go b/internal/access/config_access/provider.go index 97a64fe2..70824524 100644 --- a/internal/access/config_access/provider.go +++ b/internal/access/config_access/provider.go @@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess. authHeaderGoogle := r.Header.Get("X-Goog-Api-Key") authHeaderAnthropic := r.Header.Get("X-Api-Key") queryKey := "" + queryAuthToken := "" if r.URL != nil { queryKey = r.URL.Query().Get("key") + queryAuthToken = r.URL.Query().Get("auth_token") } - if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" { + if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" { return nil, sdkaccess.ErrNoCredentials } @@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess. 
{authHeaderGoogle, "x-goog-api-key"}, {authHeaderAnthropic, "x-api-key"}, {queryKey, "query-key"}, + {queryAuthToken, "query-auth-token"}, } for _, candidate := range candidates { diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index b866e00c..d4ea6510 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" ) // RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses. @@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { // It captures the URL, method, headers, and body. The request body is read and then // restored so that it can be processed by subsequent handlers. func captureRequestInfo(c *gin.Context) (*RequestInfo, error) { - // Capture URL - url := c.Request.URL.String() - if c.Request.URL.Path != "" { - url = c.Request.URL.Path - if c.Request.URL.RawQuery != "" { - url += "?" + c.Request.URL.RawQuery - } + // Capture URL with sensitive query parameters masked + maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery) + url := c.Request.URL.Path + if maskedQuery != "" { + url += "?" + maskedQuery } // Capture method diff --git a/internal/api/server.go b/internal/api/server.go index a41861c2..f4eb81e2 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -140,8 +140,10 @@ type Server struct { currentPath string // wsRoutes tracks registered websocket upgrade paths. 
- wsRouteMu sync.Mutex - wsRoutes map[string]struct{} + wsRouteMu sync.Mutex + wsRoutes map[string]struct{} + wsAuthChanged func(bool, bool) + wsAuthEnabled atomic.Bool // management handler mgmt *managementHandlers.Handler @@ -235,6 +237,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk envManagementSecret: envManagementSecret, wsRoutes: make(map[string]struct{}), } + s.wsAuthEnabled.Store(cfg.WebsocketAuth) // Save initial YAML snapshot s.oldConfigYaml, _ = yaml.Marshal(cfg) s.applyAccessConfig(nil, cfg) @@ -398,10 +401,20 @@ func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) { s.wsRoutes[trimmed] = struct{}{} s.wsRouteMu.Unlock() - s.engine.GET(trimmed, func(c *gin.Context) { + authMiddleware := AuthMiddleware(s.accessManager) + conditionalAuth := func(c *gin.Context) { + if !s.wsAuthEnabled.Load() { + c.Next() + return + } + authMiddleware(c) + } + finalHandler := func(c *gin.Context) { handler.ServeHTTP(c.Writer, c.Request) c.Abort() - }) + } + + s.engine.GET(trimmed, conditionalAuth, finalHandler) } func (s *Server) registerManagementRoutes() { @@ -803,6 +816,10 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.applyAccessConfig(oldCfg, cfg) s.cfg = cfg + s.wsAuthEnabled.Store(cfg.WebsocketAuth) + if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth { + s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth) + } managementasset.SetCurrentConfig(cfg) // Save YAML snapshot for next comparison s.oldConfigYaml, _ = yaml.Marshal(cfg) @@ -843,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) { ) } +func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) { + if s == nil { + return + } + s.wsAuthChanged = fn +} + // (management handlers moved to internal/api/handlers/management) // AuthMiddleware returns a Gin middleware handler that authenticates requests diff --git a/internal/config/config.go b/internal/config/config.go index 
169eecc2..bc4d217a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -40,6 +40,9 @@ type Config struct { // QuotaExceeded defines the behavior when a quota is exceeded. QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"` + // WebsocketAuth enables or disables authentication for the WebSocket API. + WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"` + // GlAPIKey is the API key for the generative language API. GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"` diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index 904fa797..2933a0bb 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) @@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path - raw := c.Request.URL.RawQuery + raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery) c.Next() diff --git a/internal/util/provider.go b/internal/util/provider.go index 5f4dcd19..8c6cefdb 100644 --- a/internal/util/provider.go +++ b/internal/util/provider.go @@ -4,6 +4,7 @@ package util import ( + "net/url" "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string { return value } } + +// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string. 
+func MaskSensitiveQuery(raw string) string { + if raw == "" { + return "" + } + parts := strings.Split(raw, "&") + changed := false + for i, part := range parts { + if part == "" { + continue + } + keyPart := part + valuePart := "" + if idx := strings.Index(part, "="); idx >= 0 { + keyPart = part[:idx] + valuePart = part[idx+1:] + } + decodedKey, err := url.QueryUnescape(keyPart) + if err != nil { + decodedKey = keyPart + } + if !shouldMaskQueryParam(decodedKey) { + continue + } + decodedValue, err := url.QueryUnescape(valuePart) + if err != nil { + decodedValue = valuePart + } + masked := HideAPIKey(strings.TrimSpace(decodedValue)) + parts[i] = keyPart + "=" + url.QueryEscape(masked) + changed = true + } + if !changed { + return raw + } + return strings.Join(parts, "&") +} + +func shouldMaskQueryParam(key string) bool { + key = strings.ToLower(strings.TrimSpace(key)) + if key == "" { + return false + } + key = strings.TrimSuffix(key, "[]") + if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") { + return true + } + if strings.Contains(key, "token") || strings.Contains(key, "secret") { + return true + } + return false +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 85b48aae..93694710 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.ProxyURL != newCfg.ProxyURL { changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL)) } + if oldCfg.WebsocketAuth != newCfg.WebsocketAuth { + changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth)) + } // Quota-exceeded behavior if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index b0f4605b..ada70eb5 100644 --- 
a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -421,6 +421,22 @@ func (s *Service) Run(ctx context.Context) error { s.ensureWebsocketGateway() if s.server != nil && s.wsGateway != nil { s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler()) + s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) { + if oldEnabled == newEnabled { + return + } + if !oldEnabled && newEnabled { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if errStop := s.wsGateway.Stop(ctx); errStop != nil { + log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop) + return + } + log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication") + return + } + log.Debugf("ws-auth disabled; existing websocket sessions remain connected") + }) } if s.hooks.OnBeforeStart != nil { @@ -460,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error { s.cfg = newCfg s.cfgMu.Unlock() s.rebindExecutors() - } watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback) From 7459c2c81af86d4cb91dcec48dc0c8b6e4278079 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 26 Oct 2025 16:28:20 +0800 Subject: [PATCH 6/6] fix(aistudio): remove generationConfig and tools when action is countTokens --- internal/runtime/executor/aistudio_executor.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 53de71c8..de90c63a 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -192,6 +192,10 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A if err != nil { return cliproxyexecutor.Response{}, err } + + body.payload, _ = sjson.DeleteBytes(body.payload, "generationConfig") + 
body.payload, _ = sjson.DeleteBytes(body.payload, "tools") + endpoint := e.buildEndpoint(req.Model, "countTokens", "") wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost,