fix(openai): add session reference counter and cache lifecycle management for websocket tools

This commit is contained in:
Luis Pater
2026-04-01 17:28:50 +08:00
parent d1c07a091e
commit acf98ed10e
2 changed files with 83 additions and 6 deletions

View File

@@ -55,6 +55,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
passthroughSessionID := uuid.NewString()
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
retainResponsesWebsocketToolCaches(downstreamSessionKey)
clientRemoteAddr := ""
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
@@ -63,6 +64,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var wsTerminateErr error
var wsBodyLog strings.Builder
defer func() {
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
if wsTerminateErr != nil {
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {

View File

@@ -16,8 +16,9 @@ const (
websocketToolOutputCacheTTL = 30 * time.Minute
)
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolSessionRefs = newWebsocketToolSessionRefCounter()
type websocketToolOutputCache struct {
mu sync.Mutex
@@ -33,7 +34,7 @@ type websocketToolOutputSession struct {
}
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
if ttl <= 0 {
if ttl < 0 {
ttl = websocketToolOutputCacheTTL
}
if maxPerSession <= 0 {
@@ -122,13 +123,22 @@ func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
}
}
func (c *websocketToolOutputCache) deleteSession(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
delete(c.sessions, sessionKey)
}
func websocketDownstreamSessionKey(req *http.Request) string {
if req == nil {
return ""
}
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
return sessionID
}
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
return requestID
}
@@ -137,9 +147,74 @@ func websocketDownstreamSessionKey(req *http.Request) string {
return sessionID
}
}
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
return sessionID
}
return ""
}
type websocketToolSessionRefCounter struct {
mu sync.Mutex
counts map[string]int
}
func newWebsocketToolSessionRefCounter() *websocketToolSessionRefCounter {
return &websocketToolSessionRefCounter{counts: make(map[string]int)}
}
func (c *websocketToolSessionRefCounter) acquire(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.counts[sessionKey]++
}
func (c *websocketToolSessionRefCounter) release(sessionKey string) bool {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
count := c.counts[sessionKey]
if count <= 1 {
delete(c.counts, sessionKey)
return true
}
c.counts[sessionKey] = count - 1
return false
}
func retainResponsesWebsocketToolCaches(sessionKey string) {
if defaultWebsocketToolSessionRefs == nil {
return
}
defaultWebsocketToolSessionRefs.acquire(sessionKey)
}
func releaseResponsesWebsocketToolCaches(sessionKey string) {
if defaultWebsocketToolSessionRefs == nil {
return
}
if !defaultWebsocketToolSessionRefs.release(sessionKey) {
return
}
if defaultWebsocketToolOutputCache != nil {
defaultWebsocketToolOutputCache.deleteSession(sessionKey)
}
if defaultWebsocketToolCallCache != nil {
defaultWebsocketToolCallCache.deleteSession(sessionKey)
}
}
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
}