From caa529c282303b171c180b92a772a1a766d8fdbb Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Wed, 1 Apr 2026 20:16:01 +0800 Subject: [PATCH] fix(openai): improve client IP retrieval in websocket handler --- .../openai/openai_responses_websocket.go | 14 +++++++---- .../openai/openai_responses_websocket_test.go | 25 +++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 9f065efd..df46d971 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -54,11 +54,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { passthroughSessionID := uuid.NewString() downstreamSessionKey := websocketDownstreamSessionKey(c.Request) retainResponsesWebsocketToolCaches(downstreamSessionKey) - clientRemoteAddr := "" - if c != nil && c.Request != nil { - clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr) - } - log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr) + clientIP := websocketClientAddress(c) + log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP) var wsTerminateErr error var wsBodyLog strings.Builder defer func() { @@ -206,6 +203,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } } +func websocketClientAddress(c *gin.Context) string { + if c == nil || c.Request == nil { + return "" + } + return strings.TrimSpace(c.ClientIP()) +} + func websocketUpgradeHeaders(req *http.Request) http.Header { headers := http.Header{} if req == nil { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 157d6e2f..773df18e 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -721,6 +721,31 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) { } } +func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, engine := gin.CreateTestContext(recorder) + if err := engine.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}); err != nil { + t.Fatalf("SetTrustedProxies: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/ws", nil) + req.RemoteAddr = "172.18.0.1:34282" + req.Header.Set("X-Forwarded-For", "203.0.113.7") + c.Request = req + + if got := websocketClientAddress(c); got != strings.TrimSpace(c.ClientIP()) { + t.Fatalf("websocketClientAddress = %q, ClientIP = %q", got, c.ClientIP()) + } +} + +func TestWebsocketClientAddressReturnsEmptyForNilContext(t *testing.T) { + if got := websocketClientAddress(nil); got != "" { + t.Fatalf("websocketClientAddress(nil) = %q, want empty", got) + } +} + func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { gin.SetMode(gin.TestMode)