mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-24 06:40:31 +00:00
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
201 lines
4.9 KiB
Go
201 lines
4.9 KiB
Go
package wsrelay
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Manager exposes a websocket endpoint that proxies Gemini requests to
|
|
// connected clients.
|
|
type Manager struct {
|
|
path string
|
|
upgrader websocket.Upgrader
|
|
sessions map[string]*session
|
|
sessMutex sync.RWMutex
|
|
|
|
providerFactory func(*http.Request) (string, error)
|
|
onConnected func(string)
|
|
onDisconnected func(string, error)
|
|
|
|
logDebugf func(string, ...any)
|
|
logInfof func(string, ...any)
|
|
logWarnf func(string, ...any)
|
|
}
|
|
|
|
// Options configures a Manager instance.
|
|
type Options struct {
|
|
Path string
|
|
ProviderFactory func(*http.Request) (string, error)
|
|
OnConnected func(string)
|
|
OnDisconnected func(string, error)
|
|
LogDebugf func(string, ...any)
|
|
LogInfof func(string, ...any)
|
|
LogWarnf func(string, ...any)
|
|
}
|
|
|
|
// NewManager builds a websocket relay manager with the supplied options.
|
|
func NewManager(opts Options) *Manager {
|
|
path := strings.TrimSpace(opts.Path)
|
|
if path == "" {
|
|
path = "/v1/ws"
|
|
}
|
|
if !strings.HasPrefix(path, "/") {
|
|
path = "/" + path
|
|
}
|
|
mgr := &Manager{
|
|
path: path,
|
|
sessions: make(map[string]*session),
|
|
upgrader: websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
},
|
|
providerFactory: opts.ProviderFactory,
|
|
onConnected: opts.OnConnected,
|
|
onDisconnected: opts.OnDisconnected,
|
|
logDebugf: opts.LogDebugf,
|
|
logInfof: opts.LogInfof,
|
|
logWarnf: opts.LogWarnf,
|
|
}
|
|
if mgr.logDebugf == nil {
|
|
mgr.logDebugf = func(string, ...any) {}
|
|
}
|
|
if mgr.logInfof == nil {
|
|
mgr.logInfof = func(string, ...any) {}
|
|
}
|
|
if mgr.logWarnf == nil {
|
|
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
|
|
}
|
|
return mgr
|
|
}
|
|
|
|
// Path returns the HTTP path the manager expects for websocket upgrades.
|
|
func (m *Manager) Path() string {
|
|
if m == nil {
|
|
return "/v1/ws"
|
|
}
|
|
return m.path
|
|
}
|
|
|
|
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
|
|
func (m *Manager) Handler() http.Handler {
|
|
return http.HandlerFunc(m.handleWebsocket)
|
|
}
|
|
|
|
// Stop gracefully closes all active websocket sessions.
|
|
func (m *Manager) Stop(_ context.Context) error {
|
|
m.sessMutex.Lock()
|
|
sessions := make([]*session, 0, len(m.sessions))
|
|
for _, sess := range m.sessions {
|
|
sessions = append(sessions, sess)
|
|
}
|
|
m.sessions = make(map[string]*session)
|
|
m.sessMutex.Unlock()
|
|
|
|
for _, sess := range sessions {
|
|
if sess != nil {
|
|
sess.cleanup(errors.New("wsrelay: manager stopped"))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleWebsocket upgrades the connection and wires the session into the pool.
|
|
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|
expectedPath := m.Path()
|
|
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
if !strings.EqualFold(r.Method, http.MethodGet) {
|
|
w.Header().Set("Allow", http.MethodGet)
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
conn, err := m.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
m.logWarnf("wsrelay: upgrade failed: %v", err)
|
|
return
|
|
}
|
|
s := newSession(conn, m, randomProviderName())
|
|
if m.providerFactory != nil {
|
|
name, err := m.providerFactory(r)
|
|
if err != nil {
|
|
s.cleanup(err)
|
|
return
|
|
}
|
|
if strings.TrimSpace(name) != "" {
|
|
s.provider = strings.ToLower(name)
|
|
}
|
|
}
|
|
if s.provider == "" {
|
|
s.provider = strings.ToLower(s.id)
|
|
}
|
|
m.sessMutex.Lock()
|
|
if existing, ok := m.sessions[s.provider]; ok {
|
|
existing.cleanup(errors.New("replaced by new connection"))
|
|
}
|
|
m.sessions[s.provider] = s
|
|
m.sessMutex.Unlock()
|
|
if m.onConnected != nil {
|
|
m.onConnected(s.provider)
|
|
}
|
|
|
|
go s.run(context.Background())
|
|
}
|
|
|
|
// Send forwards the message to the specific provider connection and returns a channel
|
|
// yielding response messages.
|
|
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
|
|
s := m.session(provider)
|
|
if s == nil {
|
|
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
|
|
}
|
|
return s.request(ctx, msg)
|
|
}
|
|
|
|
func (m *Manager) session(provider string) *session {
|
|
key := strings.ToLower(strings.TrimSpace(provider))
|
|
m.sessMutex.RLock()
|
|
s := m.sessions[key]
|
|
m.sessMutex.RUnlock()
|
|
return s
|
|
}
|
|
|
|
func (m *Manager) handleSessionClosed(s *session, cause error) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
key := strings.ToLower(strings.TrimSpace(s.provider))
|
|
m.sessMutex.Lock()
|
|
if cur, ok := m.sessions[key]; ok && cur == s {
|
|
delete(m.sessions, key)
|
|
}
|
|
m.sessMutex.Unlock()
|
|
if m.onDisconnected != nil {
|
|
m.onDisconnected(s.provider, cause)
|
|
}
|
|
}
|
|
|
|
func randomProviderName() string {
|
|
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
buf := make([]byte, 16)
|
|
if _, err := rand.Read(buf); err != nil {
|
|
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
|
|
}
|
|
for i := range buf {
|
|
buf[i] = alphabet[int(buf[i])%len(alphabet)]
|
|
}
|
|
return "aistudio-" + string(buf)
|
|
}
|