Merge pull request #2490 from router-for-me/logs

Refactor websocket logging and error handling
This commit is contained in:
Luis Pater
2026-04-02 20:47:31 +08:00
committed by GitHub
8 changed files with 926 additions and 124 deletions

View File

@@ -15,6 +15,8 @@ import (
) )
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE" const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes. // RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct { type RequestInfo struct {
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if len(apiResponse) > 0 { if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse) _ = w.streamWriter.WriteAPIResponse(apiResponse)
} }
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
if len(apiWebsocketTimeline) > 0 {
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
}
if err := w.streamWriter.Close(); err != nil { if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil w.streamWriter = nil
return err return err
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil return nil
} }
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
} }
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string { func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
return data return data
} }
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
if !isExist {
return nil
}
data, ok := apiTimeline.([]byte)
if !ok || len(data) == 0 {
return nil
}
return bytes.Clone(data)
}
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time { func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP") ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist { if !isExist {
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
} }
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte { func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil { if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist { return body
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
} }
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 { if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body return w.requestInfo.Body
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
return nil return nil
} }
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.body == nil || w.body.Len() == 0 {
return nil
}
return bytes.Clone(w.body.Bytes())
}
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
}
func extractBodyOverride(c *gin.Context, key string) []byte {
if c == nil {
return nil
}
bodyOverride, isExist := c.Get(key)
if !isExist {
return nil
}
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil { if w.requestInfo == nil {
return nil return nil
} }
if loggerWithOptions, ok := w.logger.(interface { if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok { }); ok {
return loggerWithOptions.LogRequestWithOptions( return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL, w.requestInfo.URL,
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode, statusCode,
headers, headers,
body, body,
websocketTimeline,
apiRequestBody, apiRequestBody,
apiResponseBody, apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors, apiResponseErrors,
forceLog, forceLog,
w.requestInfo.RequestID, w.requestInfo.RequestID,
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode, statusCode,
headers, headers,
body, body,
websocketTimeline,
apiRequestBody, apiRequestBody,
apiResponseBody, apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors, apiResponseErrors,
w.requestInfo.RequestID, w.requestInfo.RequestID,
w.requestInfo.Timestamp, w.requestInfo.Timestamp,

View File

@@ -1,10 +1,14 @@
package middleware package middleware
import ( import (
"bytes"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
) )
func TestExtractRequestBodyPrefersOverride(t *testing.T) { func TestExtractRequestBodyPrefersOverride(t *testing.T) {
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder) c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{} wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
c.Set(requestBodyOverrideContextKey, "override-as-string") c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c) body := wrapper.extractRequestBody(c)
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string") t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
} }
} }
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
wrapper.body.WriteString("original-response")
body := wrapper.extractResponseBody(c)
if string(body) != "original-response" {
t.Fatalf("response body = %q, want %q", string(body), "original-response")
}
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
body = wrapper.extractResponseBody(c)
if string(body) != "override-response" {
t.Fatalf("response body = %q, want %q", string(body), "override-response")
}
body[0] = 'X'
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
t.Fatalf("response override should be cloned, got %q", string(got))
}
}
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
body := wrapper.extractResponseBody(c)
if string(body) != "override-response-as-string" {
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
}
}
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
override := []byte("body-override")
c.Set(requestBodyOverrideContextKey, override)
body := extractBodyOverride(c, requestBodyOverrideContextKey)
if !bytes.Equal(body, override) {
t.Fatalf("body override = %q, want %q", string(body), string(override))
}
body[0] = 'X'
if !bytes.Equal(override, []byte("body-override")) {
t.Fatalf("override mutated: %q", string(override))
}
}
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
if got := wrapper.extractWebsocketTimeline(c); got != nil {
t.Fatalf("expected nil websocket timeline, got %q", string(got))
}
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
body := wrapper.extractWebsocketTimeline(c)
if string(body) != "timeline" {
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
}
}
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
streamWriter := &testStreamingLogWriter{}
wrapper := &ResponseWriterWrapper{
ResponseWriter: c.Writer,
logger: &testRequestLogger{enabled: true},
requestInfo: &RequestInfo{
URL: "/v1/responses",
Method: "POST",
Headers: map[string][]string{"Content-Type": {"application/json"}},
RequestID: "req-1",
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
},
isStreaming: true,
streamWriter: streamWriter,
}
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
if err := wrapper.Finalize(c); err != nil {
t.Fatalf("Finalize error: %v", err)
}
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
}
if !streamWriter.closed {
t.Fatal("expected stream writer to be closed")
}
}
type testRequestLogger struct {
enabled bool
}
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
return nil
}
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
return &testStreamingLogWriter{}, nil
}
func (l *testRequestLogger) IsEnabled() bool {
return l.enabled
}
type testStreamingLogWriter struct {
apiWebsocketTimeline []byte
closed bool
}
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
func (w *testStreamingLogWriter) Close() error {
w.closed = true
return nil
}

View File

@@ -172,6 +172,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
nil,
true, true,
"issue-1711", "issue-1711",
time.Now(), time.Now(),

View File

@@ -4,6 +4,7 @@
package logging package logging
import ( import (
"bufio"
"bytes" "bytes"
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
@@ -41,15 +42,17 @@ type RequestLogger interface {
// - statusCode: The response status code // - statusCode: The response status code
// - responseHeaders: The response headers // - responseHeaders: The response headers
// - response: The raw response data // - response: The raw response data
// - websocketTimeline: Optional downstream websocket event timeline
// - apiRequest: The API request data // - apiRequest: The API request data
// - apiResponse: The API response data // - apiResponse: The API response data
// - apiWebsocketTimeline: Optional upstream websocket event timeline
// - requestID: Optional request ID for log file naming // - requestID: Optional request ID for log file naming
// - requestTimestamp: When the request was received // - requestTimestamp: When the request was received
// - apiResponseTimestamp: When the API response was received // - apiResponseTimestamp: When the API response was received
// //
// Returns: // Returns:
// - error: An error if logging fails, nil otherwise // - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks. // LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
// //
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise // - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error WriteAPIResponse(apiResponse []byte) error
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
// This should be called when upstream communication happened over websocket.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received. // SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
// //
// Parameters: // Parameters:
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
// //
// Returns: // Returns:
// - error: An error if logging fails, nil otherwise // - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp) return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
} }
// LogRequestWithOptions logs a request with optional forced logging behavior. // LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled. // The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
} }
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
if !l.enabled && !force { if !l.enabled && !force {
return nil return nil
} }
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
requestHeaders, requestHeaders,
body, body,
requestBodyPath, requestBodyPath,
websocketTimeline,
apiRequest, apiRequest,
apiResponse, apiResponse,
apiWebsocketTimeline,
apiResponseErrors, apiResponseErrors,
statusCode, statusCode,
responseHeaders, responseHeaders,
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
requestHeaders map[string][]string, requestHeaders map[string][]string,
requestBody []byte, requestBody []byte,
requestBodyPath string, requestBodyPath string,
websocketTimeline []byte,
apiRequest []byte, apiRequest []byte,
apiResponse []byte, apiResponse []byte,
apiWebsocketTimeline []byte,
apiResponseErrors []*interfaces.ErrorMessage, apiResponseErrors []*interfaces.ErrorMessage,
statusCode int, statusCode int,
responseHeaders map[string][]string, responseHeaders map[string][]string,
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if requestTimestamp.IsZero() { if requestTimestamp.IsZero() {
requestTimestamp = time.Now() requestTimestamp = time.Now()
} }
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil { isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite return errWrite
} }
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil { if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil { if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
return errWrite return errWrite
} }
if isWebsocketTranscript {
// Intentionally omit the generic downstream HTTP response section for websocket
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
// and appending a one-off upgrade response snapshot would dilute that transcript.
return nil
}
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true) return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
} }
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
body []byte, body []byte,
bodyPath string, bodyPath string,
timestamp time.Time, timestamp time.Time,
downstreamTransport string,
upstreamTransport string,
includeBody bool,
) error { ) error {
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil { if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
return errWrite return errWrite
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil { if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
return errWrite return errWrite
} }
if strings.TrimSpace(downstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
return errWrite
}
}
if strings.TrimSpace(upstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil { if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
return errWrite return errWrite
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite return errWrite
} }
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
} }
} }
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite return errWrite
} }
if !includeBody {
return nil
}
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil { if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
return errWrite return errWrite
} }
bodyTrailingNewlines := 1
if bodyPath != "" { if bodyPath != "" {
bodyFile, errOpen := os.Open(bodyPath) bodyFile, errOpen := os.Open(bodyPath)
if errOpen != nil { if errOpen != nil {
return errOpen return errOpen
} }
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil { tracker := &trailingNewlineTrackingWriter{writer: w}
written, errCopy := io.Copy(tracker, bodyFile)
if errCopy != nil {
_ = bodyFile.Close() _ = bodyFile.Close()
return errCopy return errCopy
} }
if written > 0 {
bodyTrailingNewlines = tracker.trailingNewlines
}
if errClose := bodyFile.Close(); errClose != nil { if errClose := bodyFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request body temp file") log.WithError(errClose).Warn("failed to close request body temp file")
} }
} else if _, errWrite := w.Write(body); errWrite != nil { } else if _, errWrite := w.Write(body); errWrite != nil {
return errWrite return errWrite
} else if len(body) > 0 {
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
} }
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
return errWrite return errWrite
} }
return nil return nil
} }
func countTrailingNewlinesBytes(payload []byte) int {
count := 0
for i := len(payload) - 1; i >= 0; i-- {
if payload[i] != '\n' {
break
}
count++
}
return count
}
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
missingNewlines := 3 - trailingNewlines
if missingNewlines <= 0 {
return nil
}
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
return errWrite
}
type trailingNewlineTrackingWriter struct {
writer io.Writer
trailingNewlines int
}
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
written, errWrite := t.writer.Write(payload)
if written > 0 {
writtenPayload := payload[:written]
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
if trailingNewlines == len(writtenPayload) {
t.trailingNewlines += trailingNewlines
} else {
t.trailingNewlines = trailingNewlines
}
}
return written, errWrite
}
func hasSectionPayload(payload []byte) bool {
return len(bytes.TrimSpace(payload)) > 0
}
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
if hasSectionPayload(websocketTimeline) {
return "websocket"
}
for key, values := range headers {
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
for _, value := range values {
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
return "websocket"
}
}
}
}
return "http"
}
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
hasWS := hasSectionPayload(apiWebsocketTimeline)
switch {
case hasHTTP && hasWS:
return "websocket+http"
case hasWS:
return "websocket"
case hasHTTP:
return "http"
default:
return ""
}
}
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error { func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
if len(payload) == 0 { if len(payload) == 0 {
return nil return nil
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil { if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite return errWrite
} }
if !bytes.HasSuffix(payload, []byte("\n")) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
} else { } else {
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil { if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
return errWrite return errWrite
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil { if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite return errWrite
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
return errWrite return errWrite
} }
return nil return nil
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil { if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
return errWrite return errWrite
} }
trailingNewlines := 1
if apiResponseErrors[i].Error != nil { if apiResponseErrors[i].Error != nil {
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil { errText := apiResponseErrors[i].Error.Error()
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
return errWrite return errWrite
} }
if errText != "" {
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
}
} }
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
return errWrite return errWrite
} }
} }
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
} }
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { var bufferedReader *bufio.Reader
return errWrite if responseReader != nil {
bufferedReader = bufio.NewReader(responseReader)
}
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
} }
if responseReader != nil { if bufferedReader != nil {
if _, errCopy := io.Copy(w, responseReader); errCopy != nil { if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
return errCopy return errCopy
} }
} }
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
return nil return nil
} }
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
if reader == nil {
return false
}
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
return true
}
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
return true
}
return false
}
// formatLogContent creates the complete log content for non-streaming requests. // formatLogContent creates the complete log content for non-streaming requests.
// //
// Parameters: // Parameters:
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
// - method: The HTTP method // - method: The HTTP method
// - headers: The request headers // - headers: The request headers
// - body: The request body // - body: The request body
// - websocketTimeline: The downstream websocket event timeline
// - apiRequest: The API request data // - apiRequest: The API request data
// - apiResponse: The API response data // - apiResponse: The API response data
// - response: The raw response data // - response: The raw response data
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
// //
// Returns: // Returns:
// - string: The formatted log content // - string: The formatted log content
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string { func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
var content strings.Builder var content strings.Builder
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
// Request info // Request info
content.WriteString(l.formatRequestInfo(url, method, headers, body)) content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
if len(websocketTimeline) > 0 {
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
content.Write(websocketTimeline)
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
content.Write(websocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiWebsocketTimeline) > 0 {
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
content.Write(apiWebsocketTimeline)
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
content.Write(apiWebsocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiRequest) > 0 { if len(apiRequest) > 0 {
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) { if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
content.WriteString("\n") content.WriteString("\n")
} }
if isWebsocketTranscript {
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
// timeline sections instead of a generic downstream HTTP response block.
return content.String()
}
// Response section // Response section
content.WriteString("=== RESPONSE ===\n") content.WriteString("=== RESPONSE ===\n")
content.WriteString(fmt.Sprintf("Status: %d\n", status)) content.WriteString(fmt.Sprintf("Status: %d\n", status))
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
// //
// Returns: // Returns:
// - string: The formatted request information // - string: The formatted request information
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string { func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
var content strings.Builder var content strings.Builder
content.WriteString("=== REQUEST INFO ===\n") content.WriteString("=== REQUEST INFO ===\n")
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version)) content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
content.WriteString(fmt.Sprintf("URL: %s\n", url)) content.WriteString(fmt.Sprintf("URL: %s\n", url))
content.WriteString(fmt.Sprintf("Method: %s\n", method)) content.WriteString(fmt.Sprintf("Method: %s\n", method))
if strings.TrimSpace(downstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
}
if strings.TrimSpace(upstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
}
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
content.WriteString("\n") content.WriteString("\n")
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
} }
content.WriteString("\n") content.WriteString("\n")
if !includeBody {
return content.String()
}
content.WriteString("=== REQUEST BODY ===\n") content.WriteString("=== REQUEST BODY ===\n")
content.Write(body) content.Write(body)
content.WriteString("\n\n") content.WriteString("\n\n")
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
// apiResponse stores the upstream API response data. // apiResponse stores the upstream API response data.
apiResponse []byte apiResponse []byte
// apiWebsocketTimeline stores the upstream websocket event timeline.
apiWebsocketTimeline []byte
// apiResponseTimestamp captures when the API response was received. // apiResponseTimestamp captures when the API response was received.
apiResponseTimestamp time.Time apiResponseTimestamp time.Time
} }
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
return nil return nil
} }
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
if len(apiWebsocketTimeline) == 0 {
return nil
}
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) { func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
if !timestamp.IsZero() { if !timestamp.IsZero() {
w.apiResponseTimestamp = timestamp w.apiResponseTimestamp = timestamp
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
// Close finalizes the log file and cleans up resources. // Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order: // It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
// //
// Returns: // Returns:
// - error: An error if closing fails, nil otherwise // - error: An error if closing fails, nil otherwise
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
} }
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error { func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil { if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite return errWrite
} }
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil { if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil return nil
} }
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
return nil
}
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {} func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
// Close is a no-op implementation that does nothing and always returns nil. // Close is a no-op implementation that does nothing and always returns nil.

View File

@@ -219,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
wsReqBody := buildCodexWebsocketRequestBody(body) wsReqBody := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ wsReqLog := helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -229,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthLabel: authLabel, AuthLabel: authLabel,
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) }
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if respHS != nil {
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil { if errDial != nil {
bodyErr := websocketHandshakeBody(respHS) bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 { if respHS != nil {
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr) helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
} }
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.Execute(ctx, auth, req, opts) return e.CodexExecutor.Execute(ctx, auth, req, opts)
@@ -246,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if respHS != nil && respHS.StatusCode > 0 { if respHS != nil && respHS.StatusCode > 0 {
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
} }
helps.RecordAPIResponseError(ctx, e.cfg, errDial) helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
return resp, errDial return resp, errDial
} }
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil { if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL) logCodexWebsocketConnected(executionSessionID, authID, wsURL)
defer func() { defer func() {
@@ -278,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
// Retry once with a fresh websocket connection. This is mainly to handle // Retry once with a fresh websocket connection. This is mainly to handle
// upstream closing the socket between sequential requests within the same // upstream closing the socket between sequential requests within the same
// execution session. // execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil { if errDialRetry == nil && connRetry != nil {
wsReqBodyRetry := buildCodexWebsocketRequestBody(body) wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -292,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil { if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
conn = connRetry conn = connRetry
wsReqBody = wsReqBodyRetry wsReqBody = wsReqBodyRetry
} else { } else {
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry) helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
return resp, errSendRetry return resp, errSendRetry
} }
} else { } else {
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry) closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
return resp, errDialRetry return resp, errDialRetry
} }
} else { } else {
helps.RecordAPIResponseError(ctx, e.cfg, errSend) helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
return resp, errSend return resp, errSend
} }
} }
@@ -316,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh) msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil { if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
return resp, errRead return resp, errRead
} }
if msgType != websocket.TextMessage { if msgType != websocket.TextMessage {
@@ -325,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
} }
helps.RecordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
return resp, err return resp, err
} }
continue continue
@@ -335,13 +335,13 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if len(payload) == 0 { if len(payload) == 0 {
continue continue
} }
helps.AppendAPIResponseChunk(ctx, e.cfg, payload) helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok { if wsErr, ok := parseCodexWebsocketError(payload); ok {
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
} }
helps.RecordAPIResponseError(ctx, e.cfg, wsErr) helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
return resp, wsErr return resp, wsErr
} }
@@ -413,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
wsReqBody := buildCodexWebsocketRequestBody(body) wsReqBody := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ wsReqLog := helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -423,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthLabel: authLabel, AuthLabel: authLabel,
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) }
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header var upstreamHeaders http.Header
if respHS != nil { if respHS != nil {
upstreamHeaders = respHS.Header.Clone() upstreamHeaders = respHS.Header.Clone()
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
} }
if errDial != nil { if errDial != nil {
bodyErr := websocketHandshakeBody(respHS) bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 { if respHS != nil {
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr) helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
} }
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts) return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
@@ -442,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if respHS != nil && respHS.StatusCode > 0 { if respHS != nil && respHS.StatusCode > 0 {
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
} }
helps.RecordAPIResponseError(ctx, e.cfg, errDial) helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
if sess != nil { if sess != nil {
sess.reqMu.Unlock() sess.reqMu.Unlock()
} }
return nil, errDial return nil, errDial
} }
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil { if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL) logCodexWebsocketConnected(executionSessionID, authID, wsURL)
@@ -461,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errSend) helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend) e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a new websocket connection for the same execution session. // Retry once with a new websocket connection for the same execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry != nil || connRetry == nil { if errDialRetry != nil || connRetry == nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry) closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
sess.clearActive(readCh) sess.clearActive(readCh)
sess.reqMu.Unlock() sess.reqMu.Unlock()
return nil, errDialRetry return nil, errDialRetry
} }
wsReqBodyRetry := buildCodexWebsocketRequestBody(body) wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -485,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil { if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry) helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
sess.clearActive(readCh) sess.clearActive(readCh)
sess.reqMu.Unlock() sess.reqMu.Unlock()
@@ -552,7 +554,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
terminateReason = "read_error" terminateReason = "read_error"
terminateErr = errRead terminateErr = errRead
helps.RecordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
reporter.PublishFailure(ctx) reporter.PublishFailure(ctx)
_ = send(cliproxyexecutor.StreamChunk{Err: errRead}) _ = send(cliproxyexecutor.StreamChunk{Err: errRead})
return return
@@ -562,7 +564,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
err = fmt.Errorf("codex websockets executor: unexpected binary message") err = fmt.Errorf("codex websockets executor: unexpected binary message")
terminateReason = "unexpected_binary" terminateReason = "unexpected_binary"
terminateErr = err terminateErr = err
helps.RecordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
reporter.PublishFailure(ctx) reporter.PublishFailure(ctx)
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
@@ -577,12 +579,12 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if len(payload) == 0 { if len(payload) == 0 {
continue continue
} }
helps.AppendAPIResponseChunk(ctx, e.cfg, payload) helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok { if wsErr, ok := parseCodexWebsocketError(payload); ok {
terminateReason = "upstream_error" terminateReason = "upstream_error"
terminateErr = wsErr terminateErr = wsErr
helps.RecordAPIResponseError(ctx, e.cfg, wsErr) helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
reporter.PublishFailure(ctx) reporter.PublishFailure(ctx)
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
@@ -1022,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
return line return line
} }
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
upgradeInfo := info
upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL)
upgradeInfo.Method = http.MethodGet
upgradeInfo.Body = nil
upgradeInfo.Headers = info.Headers.Clone()
if upgradeInfo.Headers == nil {
upgradeInfo.Headers = make(http.Header)
}
if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" {
upgradeInfo.Headers.Set("Connection", "Upgrade")
}
if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" {
upgradeInfo.Headers.Set("Upgrade", "websocket")
}
return upgradeInfo
}
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
if resp == nil {
return
}
helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
}
func websocketHandshakeBody(resp *http.Response) []byte { func websocketHandshakeBody(resp *http.Response) []byte {
if resp == nil || resp.Body == nil { if resp == nil || resp.Body == nil {
return nil return nil

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"html" "html"
"net/http" "net/http"
"net/url"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -19,9 +20,10 @@ import (
) )
const ( const (
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS" apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
apiRequestKey = "API_REQUEST" apiRequestKey = "API_REQUEST"
apiResponseKey = "API_RESPONSE" apiResponseKey = "API_RESPONSE"
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
) )
// UpstreamRequestLog captures the outbound upstream request details for logging. // UpstreamRequestLog captures the outbound upstream request details for logging.
@@ -46,6 +48,7 @@ type upstreamAttempt struct {
headersWritten bool headersWritten bool
bodyStarted bool bodyStarted bool
bodyHasContent bool bodyHasContent bool
prevWasSSEEvent bool
errorWritten bool errorWritten bool
} }
@@ -173,15 +176,157 @@ func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
attempt.response.WriteString("Body:\n") attempt.response.WriteString("Body:\n")
attempt.bodyStarted = true attempt.bodyStarted = true
} }
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
if attempt.bodyHasContent { if attempt.bodyHasContent {
attempt.response.WriteString("\n\n") separator := "\n\n"
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
separator = "\n"
}
attempt.response.WriteString(separator)
} }
attempt.response.WriteString(string(data)) attempt.response.WriteString(string(data))
attempt.bodyHasContent = true attempt.bodyHasContent = true
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
updateAggregatedResponse(ginCtx, attempts) updateAggregatedResponse(ginCtx, attempts)
} }
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.request\n")
if info.URL != "" {
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
}
if auth := formatAuthInfo(info); auth != "" {
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, info.Headers)
builder.WriteString("\nBody:\n")
if len(info.Body) > 0 {
builder.Write(info.Body)
} else {
builder.WriteString("<empty>")
}
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.handshake\n")
if status > 0 {
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, headers)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
RecordAPIRequest(ctx, cfg, info)
RecordAPIResponseMetadata(ctx, cfg, status, headers)
AppendAPIResponseChunk(ctx, cfg, body)
}
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
func WebsocketUpgradeRequestURL(rawURL string) string {
trimmedURL := strings.TrimSpace(rawURL)
if trimmedURL == "" {
return ""
}
parsed, err := url.Parse(trimmedURL)
if err != nil {
return trimmedURL
}
switch strings.ToLower(parsed.Scheme) {
case "ws":
parsed.Scheme = "http"
case "wss":
parsed.Scheme = "https"
}
return parsed.String()
}
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
data := bytes.TrimSpace(payload)
if len(data) == 0 {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.response\n")
builder.Write(data)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
if cfg == nil || !cfg.RequestLog || err == nil {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.error\n")
if trimmed := strings.TrimSpace(stage); trimmed != "" {
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
}
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
func ginContextFrom(ctx context.Context) *gin.Context { func ginContextFrom(ctx context.Context) *gin.Context {
ginCtx, _ := ctx.Value("gin").(*gin.Context) ginCtx, _ := ctx.Value("gin").(*gin.Context)
return ginCtx return ginCtx
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
ginCtx.Set(apiResponseKey, []byte(builder.String())) ginCtx.Set(apiResponseKey, []byte(builder.String()))
} }
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
if ginCtx == nil {
return
}
data := bytes.TrimSpace(chunk)
if len(data) == 0 {
return
}
if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
combined := make([]byte, 0, len(existingBytes)+len(data)+2)
combined = append(combined, existingBytes...)
if !bytes.HasSuffix(existingBytes, []byte("\n")) {
combined = append(combined, '\n')
}
combined = append(combined, '\n')
combined = append(combined, data...)
ginCtx.Set(apiWebsocketTimelineKey, combined)
return
}
}
ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data))
}
func markAPIResponseTimestamp(ginCtx *gin.Context) {
if ginCtx == nil {
return
}
if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists {
return
}
ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
}
func writeHeaders(builder *strings.Builder, headers http.Header) { func writeHeaders(builder *strings.Builder, headers http.Header) {
if builder == nil { if builder == nil {
return return

View File

@@ -32,7 +32,7 @@ const (
wsEventTypeCompleted = "response.completed" wsEventTypeCompleted = "response.completed"
wsDoneMarker = "[DONE]" wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state" wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" wsTimelineBodyKey = "WEBSOCKET_TIMELINE_OVERRIDE"
) )
var responsesWebsocketUpgrader = websocket.Upgrader{ var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -57,10 +57,11 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
clientIP := websocketClientAddress(c) clientIP := websocketClientAddress(c)
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP) log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
var wsTerminateErr error var wsTerminateErr error
var wsBodyLog strings.Builder var wsTimelineLog strings.Builder
defer func() { defer func() {
releaseResponsesWebsocketToolCaches(downstreamSessionKey) releaseResponsesWebsocketToolCaches(downstreamSessionKey)
if wsTerminateErr != nil { if wsTerminateErr != nil {
appendWebsocketTimelineDisconnect(&wsTimelineLog, wsTerminateErr, time.Now())
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr) // log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else { } else {
log.Infof("responses websocket: session closing id=%s", passthroughSessionID) log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
@@ -69,7 +70,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
h.AuthManager.CloseExecutionSession(passthroughSessionID) h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID) log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
} }
setWebsocketRequestBody(c, wsBodyLog.String()) setWebsocketTimelineBody(c, wsTimelineLog.String())
if errClose := conn.Close(); errClose != nil { if errClose := conn.Close(); errClose != nil {
log.Warnf("responses websocket: close connection error: %v", errClose) log.Warnf("responses websocket: close connection error: %v", errClose)
} }
@@ -83,7 +84,6 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
msgType, payload, errReadMessage := conn.ReadMessage() msgType, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil { if errReadMessage != nil {
wsTerminateErr = errReadMessage wsTerminateErr = errReadMessage
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage) log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
} else { } else {
@@ -101,7 +101,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// websocketPayloadEventType(payload), // websocketPayloadEventType(payload),
// websocketPayloadPreview(payload), // websocketPayloadPreview(payload),
// ) // )
appendWebsocketEvent(&wsBodyLog, "request", payload) appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now())
allowIncrementalInputWithPreviousResponseID := false allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil { if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
@@ -128,8 +128,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
if errMsg != nil { if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c) markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) errorPayload, errWrite := writeResponsesWebsocketError(conn, &wsTimelineLog, errMsg)
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
log.Infof( log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s", "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
passthroughSessionID, passthroughSessionID,
@@ -157,9 +156,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
} }
lastRequest = updatedLastRequest lastRequest = updatedLastRequest
lastResponseOutput = []byte("[]") lastResponseOutput = []byte("[]")
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil { if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsTimelineLog, passthroughSessionID); errWrite != nil {
wsTerminateErr = errWrite wsTerminateErr = errWrite
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
return return
} }
continue continue
@@ -192,10 +190,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
} }
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
if errForward != nil { if errForward != nil {
wsTerminateErr = errForward wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return return
} }
@@ -597,7 +594,7 @@ func writeResponsesWebsocketSyntheticPrewarm(
c *gin.Context, c *gin.Context,
conn *websocket.Conn, conn *websocket.Conn,
requestJSON []byte, requestJSON []byte,
wsBodyLog *strings.Builder, wsTimelineLog *strings.Builder,
sessionID string, sessionID string,
) error { ) error {
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON) payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
@@ -606,7 +603,6 @@ func writeResponsesWebsocketSyntheticPrewarm(
} }
for i := 0; i < len(payloads); i++ { for i := 0; i < len(payloads); i++ {
markAPIResponseTimestamp(c) markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof( // log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID, // sessionID,
@@ -614,7 +610,7 @@ func writeResponsesWebsocketSyntheticPrewarm(
// websocketPayloadEventType(payloads[i]), // websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]), // websocketPayloadPreview(payloads[i]),
// ) // )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil {
log.Warnf( log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v", "responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID, sessionID,
@@ -713,7 +709,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
cancel handlers.APIHandlerCancelFunc, cancel handlers.APIHandlerCancelFunc,
data <-chan []byte, data <-chan []byte,
errs <-chan *interfaces.ErrorMessage, errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder, wsTimelineLog *strings.Builder,
sessionID string, sessionID string,
) ([]byte, error) { ) ([]byte, error) {
completed := false completed := false
@@ -736,8 +732,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
if errMsg != nil { if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c) markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof( log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s", "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID, sessionID,
@@ -771,8 +766,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
} }
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c) markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof( log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s", "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID, sessionID,
@@ -806,7 +800,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
completedOutput = responseCompletedOutputFromPayload(payloads[i]) completedOutput = responseCompletedOutputFromPayload(payloads[i])
} }
markAPIResponseTimestamp(c) markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof( // log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID, // sessionID,
@@ -814,7 +807,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// websocketPayloadEventType(payloads[i]), // websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]), // websocketPayloadPreview(payloads[i]),
// ) // )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil {
log.Warnf( log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v", "responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID, sessionID,
@@ -870,7 +863,7 @@ func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
return payloads return payloads
} }
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) { func writeResponsesWebsocketError(conn *websocket.Conn, wsTimelineLog *strings.Builder, errMsg *interfaces.ErrorMessage) ([]byte, error) {
status := http.StatusInternalServerError status := http.StatusInternalServerError
errText := http.StatusText(status) errText := http.StatusText(status)
if errMsg != nil { if errMsg != nil {
@@ -940,7 +933,7 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
} }
} }
return payload, conn.WriteMessage(websocket.TextMessage, payload) return payload, writeResponsesWebsocketPayload(conn, wsTimelineLog, payload, time.Now())
} }
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) { func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
@@ -979,7 +972,11 @@ func websocketPayloadPreview(payload []byte) string {
return previewText return previewText
} }
func setWebsocketRequestBody(c *gin.Context, body string) { func setWebsocketTimelineBody(c *gin.Context, body string) {
setWebsocketBody(c, wsTimelineBodyKey, body)
}
func setWebsocketBody(c *gin.Context, key string, body string) {
if c == nil { if c == nil {
return return
} }
@@ -987,7 +984,40 @@ func setWebsocketRequestBody(c *gin.Context, body string) {
if trimmedBody == "" { if trimmedBody == "" {
return return
} }
c.Set(wsRequestBodyKey, []byte(trimmedBody)) c.Set(key, []byte(trimmedBody))
}
func writeResponsesWebsocketPayload(conn *websocket.Conn, wsTimelineLog *strings.Builder, payload []byte, timestamp time.Time) error {
appendWebsocketTimelineEvent(wsTimelineLog, "response", payload, timestamp)
return conn.WriteMessage(websocket.TextMessage, payload)
}
func appendWebsocketTimelineDisconnect(builder *strings.Builder, err error, timestamp time.Time) {
if err == nil {
return
}
appendWebsocketTimelineEvent(builder, "disconnect", []byte(err.Error()), timestamp)
}
func appendWebsocketTimelineEvent(builder *strings.Builder, eventType string, payload []byte, timestamp time.Time) {
if builder == nil {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString("Timestamp: ")
builder.WriteString(timestamp.Format(time.RFC3339Nano))
builder.WriteString("\n")
builder.WriteString("Event: websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
} }
func markAPIResponseTimestamp(c *gin.Context) { func markAPIResponseTimestamp(c *gin.Context) {

View File

@@ -392,27 +392,45 @@ func TestAppendWebsocketEvent(t *testing.T) {
} }
} }
func TestSetWebsocketRequestBody(t *testing.T) { func TestAppendWebsocketTimelineEvent(t *testing.T) {
var builder strings.Builder
ts := time.Date(2026, time.April, 1, 12, 34, 56, 789000000, time.UTC)
appendWebsocketTimelineEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"), ts)
got := builder.String()
if !strings.Contains(got, "Timestamp: 2026-04-01T12:34:56.789Z") {
t.Fatalf("timeline timestamp not found: %s", got)
}
if !strings.Contains(got, "Event: websocket.request") {
t.Fatalf("timeline event not found: %s", got)
}
if !strings.Contains(got, "{\"type\":\"response.create\"}") {
t.Fatalf("timeline payload not found: %s", got)
}
}
func TestSetWebsocketTimelineBody(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder) c, _ := gin.CreateTestContext(recorder)
setWebsocketRequestBody(c, " \n ") setWebsocketTimelineBody(c, " \n ")
if _, exists := c.Get(wsRequestBodyKey); exists { if _, exists := c.Get(wsTimelineBodyKey); exists {
t.Fatalf("request body key should not be set for empty body") t.Fatalf("timeline body key should not be set for empty body")
} }
setWebsocketRequestBody(c, "event body") setWebsocketTimelineBody(c, "timeline body")
value, exists := c.Get(wsRequestBodyKey) value, exists := c.Get(wsTimelineBodyKey)
if !exists { if !exists {
t.Fatalf("request body key not set") t.Fatalf("timeline body key not set")
} }
bodyBytes, ok := value.([]byte) bodyBytes, ok := value.([]byte)
if !ok { if !ok {
t.Fatalf("request body key type mismatch") t.Fatalf("timeline body key type mismatch")
} }
if string(bodyBytes) != "event body" { if string(bodyBytes) != "timeline body" {
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body") t.Fatalf("timeline body = %q, want %q", string(bodyBytes), "timeline body")
} }
} }
@@ -544,14 +562,14 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
close(data) close(data)
close(errCh) close(errCh)
var bodyLog strings.Builder var timelineLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx, ctx,
conn, conn,
func(...interface{}) {}, func(...interface{}) {},
data, data,
errCh, errCh,
&bodyLog, &timelineLog,
"session-1", "session-1",
) )
if err != nil { if err != nil {
@@ -562,6 +580,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
serverErrCh <- errors.New("completed output not captured") serverErrCh <- errors.New("completed output not captured")
return return
} }
if !strings.Contains(timelineLog.String(), "Event: websocket.response") {
serverErrCh <- errors.New("websocket timeline did not capture downstream response")
return
}
serverErrCh <- nil serverErrCh <- nil
})) }))
defer server.Close() defer server.Close()
@@ -594,6 +616,116 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
} }
} }
func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing.T) {
gin.SetMode(gin.TestMode)
serverErrCh := make(chan error, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
if err != nil {
serverErrCh <- err
return
}
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Request = r
data := make(chan []byte, 1)
errCh := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
close(data)
close(errCh)
var timelineLog strings.Builder
if errClose := conn.Close(); errClose != nil {
serverErrCh <- errClose
return
}
_, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
data,
errCh,
&timelineLog,
"session-1",
)
if err == nil {
serverErrCh <- errors.New("expected websocket write failure")
return
}
if !strings.Contains(timelineLog.String(), "Event: websocket.response") {
serverErrCh <- errors.New("websocket timeline did not capture attempted downstream response")
return
}
if !strings.Contains(timelineLog.String(), "\"type\":\"response.completed\"") {
serverErrCh <- errors.New("websocket timeline did not retain attempted payload")
return
}
serverErrCh <- nil
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
_ = conn.Close()
}()
if errServer := <-serverErrCh; errServer != nil {
t.Fatalf("server error: %v", errServer)
}
}
func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
manager := coreauth.NewManager(nil, nil, nil)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
timelineCh := make(chan string, 1)
router := gin.New()
router.GET("/v1/responses/ws", func(c *gin.Context) {
h.ResponsesWebsocket(c)
timeline := ""
if value, exists := c.Get(wsTimelineBodyKey); exists {
if body, ok := value.([]byte); ok {
timeline = string(body)
}
}
timelineCh <- timeline
})
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
closePayload := websocket.FormatCloseMessage(websocket.CloseGoingAway, "client closing")
if err = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)); err != nil {
t.Fatalf("write close control: %v", err)
}
_ = conn.Close()
select {
case timeline := <-timelineCh:
if !strings.Contains(timeline, "Event: websocket.disconnect") {
t.Fatalf("websocket timeline missing disconnect event: %s", timeline)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for websocket timeline")
}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil) manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{ auth := &coreauth.Auth{