From 3839d93ba040ace1420845698b5efb5b30519af2 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 11:30:39 +0800 Subject: [PATCH] feat: add websocket routing and executor unregister API - Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes. --- go.mod | 1 + go.sum | 4 +- internal/api/server.go | 33 +++ .../runtime/executor/aistudio_executor.go | 264 ++++++++++++++++++ internal/wsrelay/http.go | 187 +++++++++++++ internal/wsrelay/manager.go | 200 +++++++++++++ internal/wsrelay/message.go | 27 ++ internal/wsrelay/session.go | 188 +++++++++++++ sdk/cliproxy/auth/manager.go | 11 + sdk/cliproxy/auth/types.go | 12 +- sdk/cliproxy/service.go | 111 ++++++++ 11 files changed, 1035 insertions(+), 3 deletions(-) create mode 100644 internal/runtime/executor/aistudio_executor.go create mode 100644 internal/wsrelay/http.go create mode 100644 internal/wsrelay/manager.go create mode 100644 internal/wsrelay/message.go create mode 100644 internal/wsrelay/session.go diff --git a/go.mod b/go.mod index df03ac4e..010c8a6e 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gin-gonic/gin v1.10.1 github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/jackc/pgx/v5 v5.7.6 github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.4 diff --git a/go.sum b/go.sum index cba1c68c..b5cfca4a 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -80,8 +82,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ= github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= -github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= -github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/api/server.go b/internal/api/server.go index aae2b0e0..a41861c2 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "strings" + "sync" "sync/atomic" "time" @@ -138,6 +139,10 @@ type Server struct { // currentPath is the absolute path to the current working directory. currentPath string + // wsRoutes tracks registered websocket upgrade paths. + wsRouteMu sync.Mutex + wsRoutes map[string]struct{} + // management handler mgmt *managementHandlers.Handler @@ -228,6 +233,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk configFilePath: configFilePath, currentPath: wd, envManagementSecret: envManagementSecret, + wsRoutes: make(map[string]struct{}), } // Save initial YAML snapshot s.oldConfigYaml, _ = yaml.Marshal(cfg) @@ -371,6 +377,33 @@ func (s *Server) setupRoutes() { // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } +// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine. +// The handler is served as-is without additional middleware beyond the standard stack already configured. +func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) { + if s == nil || s.engine == nil || handler == nil { + return + } + trimmed := strings.TrimSpace(path) + if trimmed == "" { + trimmed = "/v1/ws" + } + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + s.wsRouteMu.Lock() + if _, exists := s.wsRoutes[trimmed]; exists { + s.wsRouteMu.Unlock() + return + } + s.wsRoutes[trimmed] = struct{}{} + s.wsRouteMu.Unlock() + + s.engine.GET(trimmed, func(c *gin.Context) { + handler.ServeHTTP(c.Writer, c.Request) + c.Abort() + }) +} + func (s *Server) registerManagementRoutes() { if s == nil || s.engine == nil || s.mgmt == nil { return diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go new file mode 100644 index 00000000..3eb9af24 --- /dev/null +++ b/internal/runtime/executor/aistudio_executor.go @@ -0,0 +1,264 @@ +package executor + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/sjson" +) + +// AistudioExecutor routes AI Studio requests through a websocket-backed transport. +type AistudioExecutor struct { + provider string + relay *wsrelay.Manager + cfg *config.Config +} + +// NewAistudioExecutor constructs a websocket executor for the provider name. +func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor { + return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg} +} + +// Identifier returns the provider key served by this executor. +func (e *AistudioExecutor) Identifier() string { return e.provider } + +// PrepareRequest is a no-op because websocket transport already injects headers. +func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { + return nil +} + +func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + translatedReq, body, err := e.translateRequest(req, opts, false) + if err != nil { + return cliproxyexecutor.Response{}, err + } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) + wsReq := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: body.payload, + } + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: endpoint, + Method: http.MethodPost, + Headers: wsReq.Headers.Clone(), + Body: bytes.Clone(body.payload), + Provider: e.provider, + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return cliproxyexecutor.Response{}, err + } + recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) + if len(resp.Body) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body)) + } + if resp.Status < 200 || resp.Status >= 300 { + return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} + } + var param any + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { + translatedReq, body, err := e.translateRequest(req, opts, true) + if err != nil { + return nil, err + } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) + wsReq := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: body.payload, + } + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: endpoint, + Method: http.MethodPost, + Headers: wsReq.Headers.Clone(), + Body: bytes.Clone(body.payload), + Provider: e.provider, + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + stream, err := e.relay.Stream(ctx, e.provider, wsReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + var param any + metadataLogged := false + for event := range stream { + if event.Err != nil { + recordAPIResponseError(ctx, e.cfg, event.Err) + out <- cliproxyexecutor.StreamChunk{Err: event.Err} + return + } + switch event.Type { + case wsrelay.MessageTypeStreamStart: + if !metadataLogged && event.Status > 0 { + recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) + metadataLogged = true + } + case wsrelay.MessageTypeStreamChunk: + if len(event.Payload) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + } + lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + case wsrelay.MessageTypeStreamEnd: + return + case wsrelay.MessageTypeHTTPResp: + if !metadataLogged && event.Status > 0 { + recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) + metadataLogged = true + } + if len(event.Payload) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + } + lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + return + case wsrelay.MessageTypeError: + recordAPIResponseError(ctx, e.cfg, event.Err) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} + return + } + } + }() + return out, nil +} + +func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + translatedReq, body, err := e.translateRequest(req, opts, false) + if err != nil { + return cliproxyexecutor.Response{}, err + } + endpoint := e.buildEndpoint(req.Model, "countTokens", "") + wsReq := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: body.payload, + } + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: endpoint, + Method: http.MethodPost, + Headers: wsReq.Headers.Clone(), + Body: bytes.Clone(body.payload), + Provider: e.provider, + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return cliproxyexecutor.Response{}, err + } + recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) + if len(resp.Body) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body)) + } + if resp.Status < 200 || resp.Status >= 300 { + return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} + } + var param any + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + _ = ctx + return auth, nil +} + +type translatedPayload struct { + payload []byte + action string + toFormat sdktranslator.Format +} + +func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok { + payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride) + } + payload = disableGeminiThinkingConfig(payload, req.Model) + payload = fixGeminiImageAspectRatio(req.Model, payload) + metadataAction := "generateContent" + if req.Metadata != nil { + if action, _ := req.Metadata["action"].(string); action == "countTokens" { + metadataAction = action + } + } + action := metadataAction + if stream && action != "countTokens" { + action = "streamGenerateContent" + } + payload, _ = sjson.DeleteBytes(payload, "session_id") + return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil +} + +func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string { + base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action) + if action == "streamGenerateContent" { + if alt == "" { + return base + "?alt=sse" + } + return base + "?$alt=" + url.QueryEscape(alt) + } + if alt != "" && action != "countTokens" { + return base + "?$alt=" + url.QueryEscape(alt) + } + return base +} diff --git a/internal/wsrelay/http.go b/internal/wsrelay/http.go new file mode 100644 index 00000000..96f80ec3 --- /dev/null +++ b/internal/wsrelay/http.go @@ -0,0 +1,187 @@ +package wsrelay + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" +) + +// HTTPRequest represents a proxied HTTP request delivered to websocket clients. +type HTTPRequest struct { + Method string + URL string + Headers http.Header + Body []byte +} + +// HTTPResponse captures the response relayed back from websocket clients. +type HTTPResponse struct { + Status int + Headers http.Header + Body []byte +} + +// StreamEvent represents a streaming response event from clients. +type StreamEvent struct { + Type string + Payload []byte + Status int + Headers http.Header + Err error +} + +// RoundTrip executes a non-streaming HTTP request using the websocket provider. +func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) { + if req == nil { + return nil, fmt.Errorf("wsrelay: request is nil") + } + msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} + respCh, err := m.Send(ctx, provider, msg) + if err != nil { + return nil, err + } + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-respCh: + if !ok { + return nil, errors.New("wsrelay: connection closed during response") + } + switch msg.Type { + case MessageTypeHTTPResp: + return decodeResponse(msg.Payload), nil + case MessageTypeError: + return nil, decodeError(msg.Payload) + case MessageTypeStreamStart, MessageTypeStreamChunk: + // Ignore streaming noise in non-stream requests. + default: + } + } + } +} + +// Stream executes a streaming HTTP request and returns channel with stream events. +func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) { + if req == nil { + return nil, fmt.Errorf("wsrelay: request is nil") + } + msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} + respCh, err := m.Send(ctx, provider, msg) + if err != nil { + return nil, err + } + out := make(chan StreamEvent) + go func() { + defer close(out) + for { + select { + case <-ctx.Done(): + out <- StreamEvent{Err: ctx.Err()} + return + case msg, ok := <-respCh: + if !ok { + out <- StreamEvent{Err: errors.New("wsrelay: stream closed")} + return + } + switch msg.Type { + case MessageTypeStreamStart: + resp := decodeResponse(msg.Payload) + out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers} + case MessageTypeStreamChunk: + chunk := decodeChunk(msg.Payload) + out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk} + case MessageTypeStreamEnd: + out <- StreamEvent{Type: MessageTypeStreamEnd} + return + case MessageTypeError: + out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)} + return + case MessageTypeHTTPResp: + resp := decodeResponse(msg.Payload) + out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body} + return + default: + } + } + } + }() + return out, nil +} + +func encodeRequest(req *HTTPRequest) map[string]any { + headers := make(map[string]any, len(req.Headers)) + for key, values := range req.Headers { + copyValues := make([]string, len(values)) + copy(copyValues, values) + headers[key] = copyValues + } + return map[string]any{ + "method": req.Method, + "url": req.URL, + "headers": headers, + "body": string(req.Body), + "sent_at": time.Now().UTC().Format(time.RFC3339Nano), + } +} + +func decodeResponse(payload map[string]any) *HTTPResponse { + if payload == nil { + return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)} + } + resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + if status, ok := payload["status"].(float64); ok { + resp.Status = int(status) + } + if headers, ok := payload["headers"].(map[string]any); ok { + for key, raw := range headers { + switch v := raw.(type) { + case []any: + for _, item := range v { + if str, ok := item.(string); ok { + resp.Headers.Add(key, str) + } + } + case []string: + for _, str := range v { + resp.Headers.Add(key, str) + } + case string: + resp.Headers.Set(key, v) + } + } + } + if body, ok := payload["body"].(string); ok { + resp.Body = []byte(body) + } + return resp +} + +func decodeChunk(payload map[string]any) []byte { + if payload == nil { + return nil + } + if data, ok := payload["data"].(string); ok { + return []byte(data) + } + return nil +} + +func decodeError(payload map[string]any) error { + if payload == nil { + return errors.New("wsrelay: unknown error") + } + message, _ := payload["error"].(string) + status := 0 + if v, ok := payload["status"].(float64); ok { + status = int(v) + } + if message == "" { + message = "wsrelay: upstream error" + } + return fmt.Errorf("%s (status=%d)", message, status) +} diff --git a/internal/wsrelay/manager.go b/internal/wsrelay/manager.go new file mode 100644 index 00000000..ab32f9f3 --- /dev/null +++ b/internal/wsrelay/manager.go @@ -0,0 +1,200 @@ +package wsrelay + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// Manager exposes a websocket endpoint that proxies Gemini requests to +// connected clients. +type Manager struct { + path string + upgrader websocket.Upgrader + sessions map[string]*session + sessMutex sync.RWMutex + + providerFactory func(*http.Request) (string, error) + onConnected func(string) + onDisconnected func(string, error) + + logDebugf func(string, ...any) + logInfof func(string, ...any) + logWarnf func(string, ...any) +} + +// Options configures a Manager instance. +type Options struct { + Path string + ProviderFactory func(*http.Request) (string, error) + OnConnected func(string) + OnDisconnected func(string, error) + LogDebugf func(string, ...any) + LogInfof func(string, ...any) + LogWarnf func(string, ...any) +} + +// NewManager builds a websocket relay manager with the supplied options. +func NewManager(opts Options) *Manager { + path := strings.TrimSpace(opts.Path) + if path == "" { + path = "/v1/ws" + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + mgr := &Manager{ + path: path, + sessions: make(map[string]*session), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + providerFactory: opts.ProviderFactory, + onConnected: opts.OnConnected, + onDisconnected: opts.OnDisconnected, + logDebugf: opts.LogDebugf, + logInfof: opts.LogInfof, + logWarnf: opts.LogWarnf, + } + if mgr.logDebugf == nil { + mgr.logDebugf = func(string, ...any) {} + } + if mgr.logInfof == nil { + mgr.logInfof = func(string, ...any) {} + } + if mgr.logWarnf == nil { + mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) } + } + return mgr +} + +// Path returns the HTTP path the manager expects for websocket upgrades. +func (m *Manager) Path() string { + if m == nil { + return "/v1/ws" + } + return m.path +} + +// Handler exposes an http.Handler that upgrades connections to websocket sessions. +func (m *Manager) Handler() http.Handler { + return http.HandlerFunc(m.handleWebsocket) +} + +// Stop gracefully closes all active websocket sessions. +func (m *Manager) Stop(_ context.Context) error { + m.sessMutex.Lock() + sessions := make([]*session, 0, len(m.sessions)) + for _, sess := range m.sessions { + sessions = append(sessions, sess) + } + m.sessions = make(map[string]*session) + m.sessMutex.Unlock() + + for _, sess := range sessions { + if sess != nil { + sess.cleanup(errors.New("wsrelay: manager stopped")) + } + } + return nil +} + +// handleWebsocket upgrades the connection and wires the session into the pool. +func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) { + expectedPath := m.Path() + if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath { + http.NotFound(w, r) + return + } + if !strings.EqualFold(r.Method, http.MethodGet) { + w.Header().Set("Allow", http.MethodGet) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + conn, err := m.upgrader.Upgrade(w, r, nil) + if err != nil { + m.logWarnf("wsrelay: upgrade failed: %v", err) + return + } + s := newSession(conn, m, randomProviderName()) + if m.providerFactory != nil { + name, err := m.providerFactory(r) + if err != nil { + s.cleanup(err) + return + } + if strings.TrimSpace(name) != "" { + s.provider = strings.ToLower(name) + } + } + if s.provider == "" { + s.provider = strings.ToLower(s.id) + } + m.sessMutex.Lock() + if existing, ok := m.sessions[s.provider]; ok { + existing.cleanup(errors.New("replaced by new connection")) + } + m.sessions[s.provider] = s + m.sessMutex.Unlock() + if m.onConnected != nil { + m.onConnected(s.provider) + } + + go s.run(context.Background()) +} + +// Send forwards the message to the specific provider connection and returns a channel +// yielding response messages. +func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) { + s := m.session(provider) + if s == nil { + return nil, fmt.Errorf("wsrelay: provider %s not connected", provider) + } + return s.request(ctx, msg) +} + +func (m *Manager) session(provider string) *session { + key := strings.ToLower(strings.TrimSpace(provider)) + m.sessMutex.RLock() + s := m.sessions[key] + m.sessMutex.RUnlock() + return s +} + +func (m *Manager) handleSessionClosed(s *session, cause error) { + if s == nil { + return + } + key := strings.ToLower(strings.TrimSpace(s.provider)) + m.sessMutex.Lock() + if cur, ok := m.sessions[key]; ok && cur == s { + delete(m.sessions, key) + } + m.sessMutex.Unlock() + if m.onDisconnected != nil { + m.onDisconnected(s.provider, cause) + } +} + +func randomProviderName() string { + const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789" + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("aistudio-%x", time.Now().UnixNano()) + } + for i := range buf { + buf[i] = alphabet[int(buf[i])%len(alphabet)] + } + return "aistudio-" + string(buf) +} diff --git a/internal/wsrelay/message.go b/internal/wsrelay/message.go new file mode 100644 index 00000000..bf716e5e --- /dev/null +++ b/internal/wsrelay/message.go @@ -0,0 +1,27 @@ +package wsrelay + +// Message represents the JSON payload exchanged with websocket clients. +type Message struct { + ID string `json:"id"` + Type string `json:"type"` + Payload map[string]any `json:"payload,omitempty"` +} + +const ( + // MessageTypeHTTPReq identifies an HTTP-style request envelope. + MessageTypeHTTPReq = "http_request" + // MessageTypeHTTPResp identifies a non-streaming HTTP response envelope. + MessageTypeHTTPResp = "http_response" + // MessageTypeStreamStart marks the beginning of a streaming response. + MessageTypeStreamStart = "stream_start" + // MessageTypeStreamChunk carries a streaming response chunk. + MessageTypeStreamChunk = "stream_chunk" + // MessageTypeStreamEnd marks the completion of a streaming response. + MessageTypeStreamEnd = "stream_end" + // MessageTypeError carries an error response. + MessageTypeError = "error" + // MessageTypePing represents ping messages from clients. + MessageTypePing = "ping" + // MessageTypePong represents pong responses back to clients. + MessageTypePong = "pong" +) diff --git a/internal/wsrelay/session.go b/internal/wsrelay/session.go new file mode 100644 index 00000000..a728cbc3 --- /dev/null +++ b/internal/wsrelay/session.go @@ -0,0 +1,188 @@ +package wsrelay + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + readTimeout = 60 * time.Second + writeTimeout = 10 * time.Second + maxInboundMessageLen = 64 << 20 // 64 MiB + heartbeatInterval = 30 * time.Second +) + +var errClosed = errors.New("websocket session closed") + +type pendingRequest struct { + ch chan Message + closeOnce sync.Once +} + +func (pr *pendingRequest) close() { + if pr == nil { + return + } + pr.closeOnce.Do(func() { + close(pr.ch) + }) +} + +type session struct { + conn *websocket.Conn + manager *Manager + provider string + id string + closed chan struct{} + closeOnce sync.Once + writeMutex sync.Mutex + pending sync.Map // map[string]*pendingRequest +} + +func newSession(conn *websocket.Conn, mgr *Manager, id string) *session { + s := &session{ + conn: conn, + manager: mgr, + provider: "", + id: id, + closed: make(chan struct{}), + } + conn.SetReadLimit(maxInboundMessageLen) + conn.SetReadDeadline(time.Now().Add(readTimeout)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + s.startHeartbeat() + return s +} + +func (s *session) startHeartbeat() { + if s == nil || s.conn == nil { + return + } + ticker := time.NewTicker(heartbeatInterval) + go func() { + defer ticker.Stop() + for { + select { + case <-s.closed: + return + case <-ticker.C: + s.writeMutex.Lock() + err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout)) + s.writeMutex.Unlock() + if err != nil { + s.cleanup(err) + return + } + } + } + }() +} + +func (s *session) run(ctx context.Context) { + defer s.cleanup(errClosed) + for { + var msg Message + if err := s.conn.ReadJSON(&msg); err != nil { + s.cleanup(err) + return + } + s.dispatch(msg) + } +} + +func (s *session) dispatch(msg Message) { + if msg.Type == MessageTypePing { + _ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong}) + return + } + if value, ok := s.pending.Load(msg.ID); ok { + req := value.(*pendingRequest) + select { + case req.ch <- msg: + default: + } + if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd { + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + } + return + } + if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd { + s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider) + } +} + +func (s *session) send(ctx context.Context, msg Message) error { + select { + case <-s.closed: + return errClosed + default: + } + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + if err := s.conn.WriteJSON(msg); err != nil { + return fmt.Errorf("write json: %w", err) + } + return nil +} + +func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) { + if msg.ID == "" { + return nil, fmt.Errorf("wsrelay: message id is required") + } + if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded { + return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID) + } + value, _ := s.pending.Load(msg.ID) + req := value.(*pendingRequest) + if err := s.send(ctx, msg); err != nil { + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + req := actual.(*pendingRequest) + req.close() + } + return nil, err + } + go func() { + select { + case <-ctx.Done(): + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + case <-s.closed: + } + }() + return req.ch, nil +} + +func (s *session) cleanup(cause error) { + s.closeOnce.Do(func() { + close(s.closed) + s.pending.Range(func(key, value any) bool { + req := value.(*pendingRequest) + msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}} + select { + case req.ch <- msg: + default: + } + req.close() + return true + }) + s.pending = sync.Map{} + _ = s.conn.Close() + if s.manager != nil { + s.manager.handleSessionClosed(s, cause) + } + }) +} diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index c2e87d9d..2cf7c77e 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) { m.executors[executor.Identifier()] = executor } +// UnregisterExecutor removes the executor associated with the provider key. +func (m *Manager) UnregisterExecutor(provider string) { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return + } + m.mu.Lock() + delete(m.executors, provider) + m.mu.Unlock() +} + // Register inserts a new auth entry into the manager. func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) { if auth == nil { diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 35594bd8..2755383f 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -156,7 +156,17 @@ func (a *Auth) AccountInfo() (string, string) { if v, ok := a.Metadata["email"].(string); ok { return "oauth", v } - } else if a.Attributes != nil { + } + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") { + if label := strings.TrimSpace(a.Label); label != "" { + return "oauth", label + } + if id := strings.TrimSpace(a.ID); id != "" { + return "oauth", id + } + return "oauth", "aistudio" + } + if a.Attributes != nil { if v := a.Attributes["api_key"]; v != "" { return "api_key", v } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index f0b44884..ccbdf903 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -18,6 +18,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" + "github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -82,6 +83,9 @@ type Service struct { // shutdownOnce ensures shutdown is called only once. shutdownOnce sync.Once + + // wsGateway manages websocket Gemini providers. + wsGateway *wsrelay.Manager } // RegisterUsagePlugin registers a usage plugin on the global usage manager. @@ -172,6 +176,66 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat } } +func (s *Service) ensureWebsocketGateway() { + if s == nil { + return + } + if s.wsGateway != nil { + return + } + opts := wsrelay.Options{ + Path: "/v1/ws", + OnConnected: s.wsOnConnected, + OnDisconnected: s.wsOnDisconnected, + LogDebugf: log.Debugf, + LogInfof: log.Infof, + LogWarnf: log.Warnf, + } + s.wsGateway = wsrelay.NewManager(opts) +} + +func (s *Service) wsOnConnected(provider string) { + if s == nil || provider == "" { + return + } + if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") { + return + } + if s.coreManager != nil { + if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil { + return + } + } + now := time.Now().UTC() + auth := &coreauth.Auth{ + ID: provider, + Provider: provider, + Label: provider, + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Attributes: map[string]string{"ws_provider": "gemini"}, + } + log.Infof("websocket provider connected: %s", provider) + s.applyCoreAuthAddOrUpdate(context.Background(), auth) +} + +func (s *Service) wsOnDisconnected(provider string, reason error) { + if s == nil || provider == "" { + return + } + if reason != nil { + log.Warnf("websocket provider disconnected: %s (%v)", provider, reason) + } else { + log.Infof("websocket provider disconnected: %s", provider) + } + ctx := context.Background() + s.applyCoreAuthRemoval(ctx, provider) + if s.coreManager != nil { + s.coreManager.UnregisterExecutor(provider) + } +} + func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) { if s == nil || auth == nil || auth.ID == "" { return @@ -247,6 +311,12 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg)) return } + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") { + if s.wsGateway != nil { + s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway)) + } + return + } switch strings.ToLower(a.Provider) { case "gemini": s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) @@ -342,6 +412,11 @@ func (s *Service) Run(ctx context.Context) error { s.authManager = newDefaultAuthManager() } + s.ensureWebsocketGateway() + if s.server != nil && s.wsGateway != nil { + s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler()) + } + if s.hooks.OnBeforeStart != nil { s.hooks.OnBeforeStart(s.cfg) } @@ -449,6 +524,14 @@ func (s *Service) Shutdown(ctx context.Context) error { shutdownErr = err } } + if s.wsGateway != nil { + if err := s.wsGateway.Stop(ctx); err != nil { + log.Errorf("failed to stop websocket gateway: %v", err) + if shutdownErr == nil { + shutdownErr = err + } + } + } if s.authQueueStop != nil { s.authQueueStop() s.authQueueStop = nil @@ -505,6 +588,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } provider := strings.ToLower(strings.TrimSpace(a.Provider)) compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a) + if a.Attributes != nil { + if strings.EqualFold(a.Attributes["ws_provider"], "gemini") { + models := mergeGeminiModels() + GlobalModelRegistry().RegisterClient(a.ID, provider, models) + return + } + } if compatDetected { provider = "openai-compatibility" } @@ -611,3 +701,24 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { GlobalModelRegistry().RegisterClient(a.ID, key, models) } } + +func mergeGeminiModels() []*ModelInfo { + models := make([]*ModelInfo, 0, 16) + seen := make(map[string]struct{}) + appendModels := func(items []*ModelInfo) { + for i := range items { + m := items[i] + if m == nil || m.ID == "" { + continue + } + if _, ok := seen[m.ID]; ok { + continue + } + seen[m.ID] = struct{}{} + models = append(models, m) + } + } + appendModels(registry.GetGeminiModels()) + appendModels(registry.GetGeminiCLIModels()) + return models +}