diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index 363278ab..7f489267 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -15,6 +15,8 @@ import ( ) const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE" +const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE" +const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE" // RequestInfo holds essential details of an incoming HTTP request for logging purposes. type RequestInfo struct { @@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { if len(apiResponse) > 0 { _ = w.streamWriter.WriteAPIResponse(apiResponse) } + apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c) + if len(apiWebsocketTimeline) > 0 { + _ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline) + } if err := w.streamWriter.Close(); err != nil { w.streamWriter = nil return err @@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { return nil } - return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) + return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) } func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string { @@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte { return data } +func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte { + apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE") + if !isExist { + return nil + } + data, ok := apiTimeline.([]byte) + if !ok || len(data) == 0 { + return nil + } + return bytes.Clone(data) +} + func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time { ts, isExist := c.Get("API_RESPONSE_TIMESTAMP") if !isExist { @@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time } func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte { - if c != nil { - if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist { - switch value := bodyOverride.(type) { - case []byte: - if len(value) > 0 { - return bytes.Clone(value) - } - case string: - if strings.TrimSpace(value) != "" { - return []byte(value) - } - } - } + if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 { + return body } if w.requestInfo != nil && len(w.requestInfo.Body) > 0 { return w.requestInfo.Body @@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte { return nil } -func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { +func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte { + if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 { + return body + } + if w.body == nil || w.body.Len() == 0 { + return nil + } + return bytes.Clone(w.body.Bytes()) +} + +func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte { + return extractBodyOverride(c, websocketTimelineOverrideContextKey) +} + +func extractBodyOverride(c *gin.Context, key string) []byte { + if c == nil { + return nil + } + bodyOverride, isExist := c.Get(key) + if !isExist { + return nil + } + switch value := bodyOverride.(type) { + case []byte: + if len(value) > 0 { + return bytes.Clone(value) + } + case string: + if strings.TrimSpace(value) != "" { + return []byte(value) + } + } + return nil +} + +func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { if w.requestInfo == nil { return nil } if loggerWithOptions, ok := w.logger.(interface { - LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error + LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error }); ok { return loggerWithOptions.LogRequestWithOptions( w.requestInfo.URL, @@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h statusCode, headers, body, + websocketTimeline, apiRequestBody, apiResponseBody, + apiWebsocketTimeline, apiResponseErrors, forceLog, w.requestInfo.RequestID, @@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h statusCode, headers, body, + websocketTimeline, apiRequestBody, apiResponseBody, + apiWebsocketTimeline, apiResponseErrors, w.requestInfo.RequestID, w.requestInfo.Timestamp, diff --git a/internal/api/middleware/response_writer_test.go b/internal/api/middleware/response_writer_test.go index fa4708e4..f5c21deb 100644 --- a/internal/api/middleware/response_writer_test.go +++ b/internal/api/middleware/response_writer_test.go @@ -1,10 +1,14 @@ package middleware import ( + "bytes" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" ) func TestExtractRequestBodyPrefersOverride(t *testing.T) { @@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) { recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) - wrapper := &ResponseWriterWrapper{} + wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}} c.Set(requestBodyOverrideContextKey, "override-as-string") body := wrapper.extractRequestBody(c) @@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) { t.Fatalf("request body = %q, want %q", string(body), "override-as-string") } } + +func TestExtractResponseBodyPrefersOverride(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}} + wrapper.body.WriteString("original-response") + + body := wrapper.extractResponseBody(c) + if string(body) != "original-response" { + t.Fatalf("response body = %q, want %q", string(body), "original-response") + } + + c.Set(responseBodyOverrideContextKey, []byte("override-response")) + body = wrapper.extractResponseBody(c) + if string(body) != "override-response" { + t.Fatalf("response body = %q, want %q", string(body), "override-response") + } + + body[0] = 'X' + if got := wrapper.extractResponseBody(c); string(got) != "override-response" { + t.Fatalf("response override should be cloned, got %q", string(got)) + } +} + +func TestExtractResponseBodySupportsStringOverride(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + wrapper := &ResponseWriterWrapper{} + c.Set(responseBodyOverrideContextKey, "override-response-as-string") + + body := wrapper.extractResponseBody(c) + if string(body) != "override-response-as-string" { + t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string") + } +} + +func TestExtractBodyOverrideClonesBytes(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + override := []byte("body-override") + c.Set(requestBodyOverrideContextKey, override) + + body := extractBodyOverride(c, requestBodyOverrideContextKey) + if !bytes.Equal(body, override) { + t.Fatalf("body override = %q, want %q", string(body), string(override)) + } + + body[0] = 'X' + if !bytes.Equal(override, []byte("body-override")) { + t.Fatalf("override mutated: %q", string(override)) + } +} + +func TestExtractWebsocketTimelineUsesOverride(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + wrapper := &ResponseWriterWrapper{} + if got := wrapper.extractWebsocketTimeline(c); got != nil { + t.Fatalf("expected nil websocket timeline, got %q", string(got)) + } + + c.Set(websocketTimelineOverrideContextKey, []byte("timeline")) + body := wrapper.extractWebsocketTimeline(c) + if string(body) != "timeline" { + t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline") + } +} + +func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + streamWriter := &testStreamingLogWriter{} + wrapper := &ResponseWriterWrapper{ + ResponseWriter: c.Writer, + logger: &testRequestLogger{enabled: true}, + requestInfo: &RequestInfo{ + URL: "/v1/responses", + Method: "POST", + Headers: map[string][]string{"Content-Type": {"application/json"}}, + RequestID: "req-1", + Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC), + }, + isStreaming: true, + streamWriter: streamWriter, + } + + c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}")) + + if err := wrapper.Finalize(c); err != nil { + t.Fatalf("Finalize error: %v", err) + } + if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" { + t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline)) + } + if !streamWriter.closed { + t.Fatal("expected stream writer to be closed") + } +} + +type testRequestLogger struct { + enabled bool +} + +func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error { + return nil +} + +func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) { + return &testStreamingLogWriter{}, nil +} + +func (l *testRequestLogger) IsEnabled() bool { + return l.enabled +} + +type testStreamingLogWriter struct { + apiWebsocketTimeline []byte + closed bool +} + +func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {} + +func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error { + return nil +} + +func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error { + return nil +} + +func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error { + return nil +} + +func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error { + w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline) + return nil +} + +func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {} + +func (w *testStreamingLogWriter) Close() error { + w.closed = true + return nil +} diff --git a/internal/api/server_test.go b/internal/api/server_test.go index f5c18aa1..7ce38b8f 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -172,6 +172,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { nil, nil, nil, + nil, + nil, true, "issue-1711", time.Now(), diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index ad7b03c1..2db2a504 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -4,6 +4,7 @@ package logging import ( + "bufio" "bytes" "compress/flate" "compress/gzip" @@ -41,15 +42,17 @@ type RequestLogger interface { // - statusCode: The response status code // - responseHeaders: The response headers // - response: The raw response data + // - websocketTimeline: Optional downstream websocket event timeline // - apiRequest: The API request data // - apiResponse: The API response data + // - apiWebsocketTimeline: Optional upstream websocket event timeline // - requestID: Optional request ID for log file naming // - requestTimestamp: When the request was received // - apiResponseTimestamp: When the API response was received // // Returns: // - error: An error if logging fails, nil otherwise - LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error + LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error // LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks. // @@ -111,6 +114,16 @@ type StreamingLogWriter interface { // - error: An error if writing fails, nil otherwise WriteAPIResponse(apiResponse []byte) error + // WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log. + // This should be called when upstream communication happened over websocket. + // + // Parameters: + // - apiWebsocketTimeline: The upstream websocket event timeline + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error + // SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received. // // Parameters: @@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) { // // Returns: // - error: An error if logging fails, nil otherwise -func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { - return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp) +func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { + return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp) } // LogRequestWithOptions logs a request with optional forced logging behavior. // The force flag allows writing error logs even when regular request logging is disabled. -func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { - return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) +func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { + return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) } -func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { +func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { if !l.enabled && !force { return nil } @@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st requestHeaders, body, requestBodyPath, + websocketTimeline, apiRequest, apiResponse, + apiWebsocketTimeline, apiResponseErrors, statusCode, responseHeaders, @@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog( requestHeaders map[string][]string, requestBody []byte, requestBodyPath string, + websocketTimeline []byte, apiRequest []byte, apiResponse []byte, + apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, statusCode int, responseHeaders map[string][]string, @@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog( if requestTimestamp.IsZero() { requestTimestamp = time.Now() } - if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil { + isWebsocketTranscript := hasSectionPayload(websocketTimeline) + downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline) + upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors) + if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil { return errWrite } if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil { @@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog( if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil { return errWrite } + if isWebsocketTranscript { + // Intentionally omit the generic downstream HTTP response section for websocket + // transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE, + // and appending a one-off upgrade response snapshot would dilute that transcript. + return nil + } return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true) } @@ -553,6 +585,9 @@ func writeRequestInfoWithBody( body []byte, bodyPath string, timestamp time.Time, + downstreamTransport string, + upstreamTransport string, + includeBody bool, ) error { if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil { return errWrite @@ -566,10 +601,20 @@ func writeRequestInfoWithBody( if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil { return errWrite } + if strings.TrimSpace(downstreamTransport) != "" { + if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil { + return errWrite + } + } + if strings.TrimSpace(upstreamTransport) != "" { + if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil { + return errWrite + } + } if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil { return errWrite } - if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + if errWrite := writeSectionSpacing(w, 1); errWrite != nil { return errWrite } @@ -584,36 +629,121 @@ func writeRequestInfoWithBody( } } } - if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + if errWrite := writeSectionSpacing(w, 1); errWrite != nil { return errWrite } + if !includeBody { + return nil + } + if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil { return errWrite } + bodyTrailingNewlines := 1 if bodyPath != "" { bodyFile, errOpen := os.Open(bodyPath) if errOpen != nil { return errOpen } - if _, errCopy := io.Copy(w, bodyFile); errCopy != nil { + tracker := &trailingNewlineTrackingWriter{writer: w} + written, errCopy := io.Copy(tracker, bodyFile) + if errCopy != nil { _ = bodyFile.Close() return errCopy } + if written > 0 { + bodyTrailingNewlines = tracker.trailingNewlines + } if errClose := bodyFile.Close(); errClose != nil { log.WithError(errClose).Warn("failed to close request body temp file") } } else if _, errWrite := w.Write(body); errWrite != nil { return errWrite + } else if len(body) > 0 { + bodyTrailingNewlines = countTrailingNewlinesBytes(body) } - - if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil { + if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil { return errWrite } return nil } +func countTrailingNewlinesBytes(payload []byte) int { + count := 0 + for i := len(payload) - 1; i >= 0; i-- { + if payload[i] != '\n' { + break + } + count++ + } + return count +} + +func writeSectionSpacing(w io.Writer, trailingNewlines int) error { + missingNewlines := 3 - trailingNewlines + if missingNewlines <= 0 { + return nil + } + _, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines)) + return errWrite +} + +type trailingNewlineTrackingWriter struct { + writer io.Writer + trailingNewlines int +} + +func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) { + written, errWrite := t.writer.Write(payload) + if written > 0 { + writtenPayload := payload[:written] + trailingNewlines := countTrailingNewlinesBytes(writtenPayload) + if trailingNewlines == len(writtenPayload) { + t.trailingNewlines += trailingNewlines + } else { + t.trailingNewlines = trailingNewlines + } + } + return written, errWrite +} + +func hasSectionPayload(payload []byte) bool { + return len(bytes.TrimSpace(payload)) > 0 +} + +func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string { + if hasSectionPayload(websocketTimeline) { + return "websocket" + } + for key, values := range headers { + if strings.EqualFold(strings.TrimSpace(key), "Upgrade") { + for _, value := range values { + if strings.EqualFold(strings.TrimSpace(value), "websocket") { + return "websocket" + } + } + } + } + return "http" +} + +func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string { + hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse) + hasWS := hasSectionPayload(apiWebsocketTimeline) + switch { + case hasHTTP && hasWS: + return "websocket+http" + case hasWS: + return "websocket" + case hasHTTP: + return "http" + default: + return "" + } +} + func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error { if len(payload) == 0 { return nil @@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa if _, errWrite := w.Write(payload); errWrite != nil { return errWrite } - if !bytes.HasSuffix(payload, []byte("\n")) { - if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { - return errWrite - } - } } else { if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil { return errWrite @@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa if _, errWrite := w.Write(payload); errWrite != nil { return errWrite } - if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { - return errWrite - } } - if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil { return errWrite } return nil @@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil { return errWrite } + trailingNewlines := 1 if apiResponseErrors[i].Error != nil { - if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil { + errText := apiResponseErrors[i].Error.Error() + if _, errWrite := io.WriteString(w, errText); errWrite != nil { return errWrite } + if errText != "" { + trailingNewlines = countTrailingNewlinesBytes([]byte(errText)) + } } - if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil { + if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil { return errWrite } } @@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo } } - if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { - return errWrite + var bufferedReader *bufio.Reader + if responseReader != nil { + bufferedReader = bufio.NewReader(responseReader) + } + if !responseBodyStartsWithLeadingNewline(bufferedReader) { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } } - if responseReader != nil { - if _, errCopy := io.Copy(w, responseReader); errCopy != nil { + if bufferedReader != nil { + if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil { return errCopy } } @@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo return nil } +func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool { + if reader == nil { + return false + } + if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' { + return true + } + if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' { + return true + } + return false +} + // formatLogContent creates the complete log content for non-streaming requests. // // Parameters: @@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo // - method: The HTTP method // - headers: The request headers // - body: The request body +// - websocketTimeline: The downstream websocket event timeline // - apiRequest: The API request data // - apiResponse: The API response data // - response: The raw response data @@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo // // Returns: // - string: The formatted log content -func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string { +func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string { var content strings.Builder + isWebsocketTranscript := hasSectionPayload(websocketTimeline) + downstreamTransport := inferDownstreamTransport(headers, websocketTimeline) + upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors) // Request info - content.WriteString(l.formatRequestInfo(url, method, headers, body)) + content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript)) + + if len(websocketTimeline) > 0 { + if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) { + content.Write(websocketTimeline) + if !bytes.HasSuffix(websocketTimeline, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== WEBSOCKET TIMELINE ===\n") + content.Write(websocketTimeline) + content.WriteString("\n") + } + content.WriteString("\n") + } + + if len(apiWebsocketTimeline) > 0 { + if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) { + content.Write(apiWebsocketTimeline) + if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API WEBSOCKET TIMELINE ===\n") + content.Write(apiWebsocketTimeline) + content.WriteString("\n") + } + content.WriteString("\n") + } if len(apiRequest) > 0 { if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) { @@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str content.WriteString("\n") } + if isWebsocketTranscript { + // Mirror writeNonStreamingLog: websocket transcripts end with the dedicated + // timeline sections instead of a generic downstream HTTP response block. + return content.String() + } + // Response section content.WriteString("=== RESPONSE ===\n") content.WriteString(fmt.Sprintf("Status: %d\n", status)) @@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) { // // Returns: // - string: The formatted request information -func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string { +func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string { var content strings.Builder content.WriteString("=== REQUEST INFO ===\n") content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version)) content.WriteString(fmt.Sprintf("URL: %s\n", url)) content.WriteString(fmt.Sprintf("Method: %s\n", method)) + if strings.TrimSpace(downstreamTransport) != "" { + content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)) + } + if strings.TrimSpace(upstreamTransport) != "" { + content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)) + } content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) content.WriteString("\n") @@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st } content.WriteString("\n") + if !includeBody { + return content.String() + } + content.WriteString("=== REQUEST BODY ===\n") content.Write(body) content.WriteString("\n\n") @@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct { // apiResponse stores the upstream API response data. apiResponse []byte + // apiWebsocketTimeline stores the upstream websocket event timeline. + apiWebsocketTimeline []byte + // apiResponseTimestamp captures when the API response was received. apiResponseTimestamp time.Time } @@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { return nil } +// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing. +// +// Parameters: +// - apiWebsocketTimeline: The upstream websocket event timeline +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error { + if len(apiWebsocketTimeline) == 0 { + return nil + } + w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline) + return nil +} + func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) { if !timestamp.IsZero() { w.apiResponseTimestamp = timestamp @@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) { // Close finalizes the log file and cleans up resources. // It writes all buffered data to the file in the correct order: -// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) +// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // // Returns: // - error: An error if closing fails, nil otherwise @@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() { } func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error { - if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil { + if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil { return errWrite } if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil { @@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { return nil } +// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiWebsocketTimeline: The upstream websocket event timeline (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error { + return nil +} + func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {} // Close is a no-op implementation that does nothing and always returns nil. diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index dc9a8a79..2041cebc 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -219,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } wsReqBody := buildCodexWebsocketRequestBody(body) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + wsReqLog := helps.UpstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", Headers: wsHeaders.Clone(), @@ -229,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, - }) + } + helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) - if respHS != nil { - helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) - } if errDial != nil { bodyErr := websocketHandshakeBody(respHS) - if len(bodyErr) > 0 { - helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr) + if respHS != nil { + helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr) } if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { return e.CodexExecutor.Execute(ctx, auth, req, opts) @@ -246,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut if respHS != nil && respHS.StatusCode > 0 { return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} } - helps.RecordAPIResponseError(ctx, e.cfg, errDial) + helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial) return resp, errDial } - closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") + recordAPIWebsocketHandshake(ctx, e.cfg, respHS) if sess == nil { logCodexWebsocketConnected(executionSessionID, authID, wsURL) defer func() { @@ -278,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut // Retry once with a fresh websocket connection. This is mainly to handle // upstream closing the socket between sequential requests within the same // execution session. - connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) if errDialRetry == nil && connRetry != nil { wsReqBodyRetry := buildCodexWebsocketRequestBody(body) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", Headers: wsHeaders.Clone(), @@ -292,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut AuthType: authType, AuthValue: authValue, }) + recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry) if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil { conn = connRetry wsReqBody = wsReqBodyRetry } else { e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) - helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry) + helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry) return resp, errSendRetry } } else { - helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry) + closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error") + helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry) return resp, errDialRetry } } else { - helps.RecordAPIResponseError(ctx, e.cfg, errSend) + helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend) return resp, errSend } } @@ -316,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh) if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) + helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead) return resp, errRead } if msgType != websocket.TextMessage { @@ -325,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut if sess != nil { e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) } - helps.RecordAPIResponseError(ctx, e.cfg, err) + helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err) return resp, err } continue @@ -335,13 +335,13 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut if len(payload) == 0 { continue } - helps.AppendAPIResponseChunk(ctx, e.cfg, payload) + helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload) if wsErr, ok := parseCodexWebsocketError(payload); ok { if sess != nil { e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) } - helps.RecordAPIResponseError(ctx, e.cfg, wsErr) + helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr) return resp, wsErr } @@ -413,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } wsReqBody := buildCodexWebsocketRequestBody(body) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + wsReqLog := helps.UpstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", Headers: wsHeaders.Clone(), @@ -423,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, - }) + } + helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) var upstreamHeaders http.Header if respHS != nil { upstreamHeaders = respHS.Header.Clone() - helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) } if errDial != nil { bodyErr := websocketHandshakeBody(respHS) - if len(bodyErr) > 0 { - helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr) + if respHS != nil { + helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr) } if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts) @@ -442,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr if respHS != nil && respHS.StatusCode > 0 { return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} } - helps.RecordAPIResponseError(ctx, e.cfg, errDial) + helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial) if sess != nil { sess.reqMu.Unlock() } return nil, errDial } - closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") + recordAPIWebsocketHandshake(ctx, e.cfg, respHS) if sess == nil { logCodexWebsocketConnected(executionSessionID, authID, wsURL) @@ -461,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errSend) + helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend) if sess != nil { e.invalidateUpstreamConn(sess, conn, "send_error", errSend) // Retry once with a new websocket connection for the same execution session. - connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) if errDialRetry != nil || connRetry == nil { - helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry) + closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error") + helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry) sess.clearActive(readCh) sess.reqMu.Unlock() return nil, errDialRetry } wsReqBodyRetry := buildCodexWebsocketRequestBody(body) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", Headers: wsHeaders.Clone(), @@ -485,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr AuthType: authType, AuthValue: authValue, }) + recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry) if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry) + helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry) e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) sess.clearActive(readCh) sess.reqMu.Unlock() @@ -552,7 +554,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } terminateReason = "read_error" terminateErr = errRead - helps.RecordAPIResponseError(ctx, e.cfg, errRead) + helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead) reporter.PublishFailure(ctx) _ = send(cliproxyexecutor.StreamChunk{Err: errRead}) return @@ -562,7 +564,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr err = fmt.Errorf("codex websockets executor: unexpected binary message") terminateReason = "unexpected_binary" terminateErr = err - helps.RecordAPIResponseError(ctx, e.cfg, err) + helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err) reporter.PublishFailure(ctx) if sess != nil { e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) @@ -577,12 +579,12 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr if len(payload) == 0 { continue } - helps.AppendAPIResponseChunk(ctx, e.cfg, payload) + helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload) if wsErr, ok := parseCodexWebsocketError(payload); ok { terminateReason = "upstream_error" terminateErr = wsErr - helps.RecordAPIResponseError(ctx, e.cfg, wsErr) + helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr) reporter.PublishFailure(ctx) if sess != nil { e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) @@ -1022,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte { return line } +func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog { + upgradeInfo := info + upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL) + upgradeInfo.Method = http.MethodGet + upgradeInfo.Body = nil + upgradeInfo.Headers = info.Headers.Clone() + if upgradeInfo.Headers == nil { + upgradeInfo.Headers = make(http.Header) + } + if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" { + upgradeInfo.Headers.Set("Connection", "Upgrade") + } + if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" { + upgradeInfo.Headers.Set("Upgrade", "websocket") + } + return upgradeInfo +} + +func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) { + if resp == nil { + return + } + helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone()) + closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error") +} + func websocketHandshakeBody(resp *http.Response) []byte { if resp == nil || resp.Body == nil { return nil diff --git a/internal/runtime/executor/helps/logging_helpers.go b/internal/runtime/executor/helps/logging_helpers.go index f9389edd..767c8820 100644 --- a/internal/runtime/executor/helps/logging_helpers.go +++ b/internal/runtime/executor/helps/logging_helpers.go @@ -6,6 +6,7 @@ import ( "fmt" "html" "net/http" + "net/url" "sort" "strings" "time" @@ -19,9 +20,10 @@ import ( ) const ( - apiAttemptsKey = "API_UPSTREAM_ATTEMPTS" - apiRequestKey = "API_REQUEST" - apiResponseKey = "API_RESPONSE" + apiAttemptsKey = "API_UPSTREAM_ATTEMPTS" + apiRequestKey = "API_REQUEST" + apiResponseKey = "API_RESPONSE" + apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE" ) // UpstreamRequestLog captures the outbound upstream request details for logging. @@ -46,6 +48,7 @@ type upstreamAttempt struct { headersWritten bool bodyStarted bool bodyHasContent bool + prevWasSSEEvent bool errorWritten bool } @@ -173,15 +176,157 @@ func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt attempt.response.WriteString("Body:\n") attempt.bodyStarted = true } + currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:")) + currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:")) if attempt.bodyHasContent { - attempt.response.WriteString("\n\n") + separator := "\n\n" + if attempt.prevWasSSEEvent && currentChunkIsSSEData { + separator = "\n" + } + attempt.response.WriteString(separator) } attempt.response.WriteString(string(data)) attempt.bodyHasContent = true + attempt.prevWasSSEEvent = currentChunkIsSSEEvent updateAggregatedResponse(ginCtx, attempts) } +// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context. +func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) { + if cfg == nil || !cfg.RequestLog { + return + } + ginCtx := ginContextFrom(ctx) + if ginCtx == nil { + return + } + + builder := &strings.Builder{} + builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) + builder.WriteString("Event: api.websocket.request\n") + if info.URL != "" { + builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL)) + } + if auth := formatAuthInfo(info); auth != "" { + builder.WriteString(fmt.Sprintf("Auth: %s\n", auth)) + } + builder.WriteString("Headers:\n") + writeHeaders(builder, info.Headers) + builder.WriteString("\nBody:\n") + if len(info.Body) > 0 { + builder.Write(info.Body) + } else { + builder.WriteString("") + } + builder.WriteString("\n") + + appendAPIWebsocketTimeline(ginCtx, []byte(builder.String())) +} + +// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata. +func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) { + if cfg == nil || !cfg.RequestLog { + return + } + ginCtx := ginContextFrom(ctx) + if ginCtx == nil { + return + } + + builder := &strings.Builder{} + builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) + builder.WriteString("Event: api.websocket.handshake\n") + if status > 0 { + builder.WriteString(fmt.Sprintf("Status: %d\n", status)) + } + builder.WriteString("Headers:\n") + writeHeaders(builder, headers) + builder.WriteString("\n") + + appendAPIWebsocketTimeline(ginCtx, []byte(builder.String())) +} + +// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt. +func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) { + if cfg == nil || !cfg.RequestLog { + return + } + ginCtx := ginContextFrom(ctx) + if ginCtx == nil { + return + } + + RecordAPIRequest(ctx, cfg, info) + RecordAPIResponseMetadata(ctx, cfg, status, headers) + AppendAPIResponseChunk(ctx, cfg, body) +} + +// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging. +func WebsocketUpgradeRequestURL(rawURL string) string { + trimmedURL := strings.TrimSpace(rawURL) + if trimmedURL == "" { + return "" + } + parsed, err := url.Parse(trimmedURL) + if err != nil { + return trimmedURL + } + switch strings.ToLower(parsed.Scheme) { + case "ws": + parsed.Scheme = "http" + case "wss": + parsed.Scheme = "https" + } + return parsed.String() +} + +// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context. +func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) { + if cfg == nil || !cfg.RequestLog { + return + } + data := bytes.TrimSpace(payload) + if len(data) == 0 { + return + } + ginCtx := ginContextFrom(ctx) + if ginCtx == nil { + return + } + markAPIResponseTimestamp(ginCtx) + + builder := &strings.Builder{} + builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) + builder.WriteString("Event: api.websocket.response\n") + builder.Write(data) + builder.WriteString("\n") + + appendAPIWebsocketTimeline(ginCtx, []byte(builder.String())) +} + +// RecordAPIWebsocketError stores an upstream websocket error event in Gin context. +func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) { + if cfg == nil || !cfg.RequestLog || err == nil { + return + } + ginCtx := ginContextFrom(ctx) + if ginCtx == nil { + return + } + markAPIResponseTimestamp(ginCtx) + + builder := &strings.Builder{} + builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) + builder.WriteString("Event: api.websocket.error\n") + if trimmed := strings.TrimSpace(stage); trimmed != "" { + builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed)) + } + builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error())) + + appendAPIWebsocketTimeline(ginCtx, []byte(builder.String())) +} + func ginContextFrom(ctx context.Context) *gin.Context { ginCtx, _ := ctx.Value("gin").(*gin.Context) return ginCtx @@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt) ginCtx.Set(apiResponseKey, []byte(builder.String())) } +func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) { + if ginCtx == nil { + return + } + data := bytes.TrimSpace(chunk) + if len(data) == 0 { + return + } + if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists { + if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { + combined := make([]byte, 0, len(existingBytes)+len(data)+2) + combined = append(combined, existingBytes...) + if !bytes.HasSuffix(existingBytes, []byte("\n")) { + combined = append(combined, '\n') + } + combined = append(combined, '\n') + combined = append(combined, data...) + ginCtx.Set(apiWebsocketTimelineKey, combined) + return + } + } + ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data)) +} + +func markAPIResponseTimestamp(ginCtx *gin.Context) { + if ginCtx == nil { + return + } + if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists { + return + } + ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now()) +} + func writeHeaders(builder *strings.Builder, headers http.Header) { if builder == nil { return diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index df46d971..1080f5cd 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -32,7 +32,7 @@ const ( wsEventTypeCompleted = "response.completed" wsDoneMarker = "[DONE]" wsTurnStateHeader = "x-codex-turn-state" - wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" + wsTimelineBodyKey = "WEBSOCKET_TIMELINE_OVERRIDE" ) var responsesWebsocketUpgrader = websocket.Upgrader{ @@ -57,10 +57,11 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { clientIP := websocketClientAddress(c) log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP) var wsTerminateErr error - var wsBodyLog strings.Builder + var wsTimelineLog strings.Builder defer func() { releaseResponsesWebsocketToolCaches(downstreamSessionKey) if wsTerminateErr != nil { + appendWebsocketTimelineDisconnect(&wsTimelineLog, wsTerminateErr, time.Now()) // log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr) } else { log.Infof("responses websocket: session closing id=%s", passthroughSessionID) @@ -69,7 +70,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { h.AuthManager.CloseExecutionSession(passthroughSessionID) log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID) } - setWebsocketRequestBody(c, wsBodyLog.String()) + setWebsocketTimelineBody(c, wsTimelineLog.String()) if errClose := conn.Close(); errClose != nil { log.Warnf("responses websocket: close connection error: %v", errClose) } @@ -83,7 +84,6 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { msgType, payload, errReadMessage := conn.ReadMessage() if errReadMessage != nil { wsTerminateErr = errReadMessage - appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error())) if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage) } else { @@ -101,7 +101,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { // websocketPayloadEventType(payload), // websocketPayloadPreview(payload), // ) - appendWebsocketEvent(&wsBodyLog, "request", payload) + appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now()) allowIncrementalInputWithPreviousResponseID := false if pinnedAuthID != "" && h != nil && h.AuthManager != nil { @@ -128,8 +128,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) - errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) - appendWebsocketEvent(&wsBodyLog, "response", errorPayload) + errorPayload, errWrite := writeResponsesWebsocketError(conn, &wsTimelineLog, errMsg) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", passthroughSessionID, @@ -157,9 +156,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } lastRequest = updatedLastRequest lastResponseOutput = []byte("[]") - if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil { + if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsTimelineLog, passthroughSessionID); errWrite != nil { wsTerminateErr = errWrite - appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error())) return } continue @@ -192,10 +190,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) + completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID) if errForward != nil { wsTerminateErr = errForward - appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) return } @@ -597,7 +594,7 @@ func writeResponsesWebsocketSyntheticPrewarm( c *gin.Context, conn *websocket.Conn, requestJSON []byte, - wsBodyLog *strings.Builder, + wsTimelineLog *strings.Builder, sessionID string, ) error { payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON) @@ -606,7 +603,6 @@ func writeResponsesWebsocketSyntheticPrewarm( } for i := 0; i < len(payloads); i++ { markAPIResponseTimestamp(c) - appendWebsocketEvent(wsBodyLog, "response", payloads[i]) // log.Infof( // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", // sessionID, @@ -614,7 +610,7 @@ func writeResponsesWebsocketSyntheticPrewarm( // websocketPayloadEventType(payloads[i]), // websocketPayloadPreview(payloads[i]), // ) - if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { + if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil { log.Warnf( "responses websocket: downstream_out write failed id=%s event=%s error=%v", sessionID, @@ -713,7 +709,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( cancel handlers.APIHandlerCancelFunc, data <-chan []byte, errs <-chan *interfaces.ErrorMessage, - wsBodyLog *strings.Builder, + wsTimelineLog *strings.Builder, sessionID string, ) ([]byte, error) { completed := false @@ -736,8 +732,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) - errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) - appendWebsocketEvent(wsBodyLog, "response", errorPayload) + errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", sessionID, @@ -771,8 +766,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( } h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) - errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) - appendWebsocketEvent(wsBodyLog, "response", errorPayload) + errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", sessionID, @@ -806,7 +800,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( completedOutput = responseCompletedOutputFromPayload(payloads[i]) } markAPIResponseTimestamp(c) - appendWebsocketEvent(wsBodyLog, "response", payloads[i]) // log.Infof( // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", // sessionID, @@ -814,7 +807,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( // websocketPayloadEventType(payloads[i]), // websocketPayloadPreview(payloads[i]), // ) - if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { + if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil { log.Warnf( "responses websocket: downstream_out write failed id=%s event=%s error=%v", sessionID, @@ -870,7 +863,7 @@ func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte { return payloads } -func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) { +func writeResponsesWebsocketError(conn *websocket.Conn, wsTimelineLog *strings.Builder, errMsg *interfaces.ErrorMessage) ([]byte, error) { status := http.StatusInternalServerError errText := http.StatusText(status) if errMsg != nil { @@ -940,7 +933,7 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error } } - return payload, conn.WriteMessage(websocket.TextMessage, payload) + return payload, writeResponsesWebsocketPayload(conn, wsTimelineLog, payload, time.Now()) } func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) { @@ -979,7 +972,11 @@ func websocketPayloadPreview(payload []byte) string { return previewText } -func setWebsocketRequestBody(c *gin.Context, body string) { +func setWebsocketTimelineBody(c *gin.Context, body string) { + setWebsocketBody(c, wsTimelineBodyKey, body) +} + +func setWebsocketBody(c *gin.Context, key string, body string) { if c == nil { return } @@ -987,7 +984,40 @@ func setWebsocketRequestBody(c *gin.Context, body string) { if trimmedBody == "" { return } - c.Set(wsRequestBodyKey, []byte(trimmedBody)) + c.Set(key, []byte(trimmedBody)) +} + +func writeResponsesWebsocketPayload(conn *websocket.Conn, wsTimelineLog *strings.Builder, payload []byte, timestamp time.Time) error { + appendWebsocketTimelineEvent(wsTimelineLog, "response", payload, timestamp) + return conn.WriteMessage(websocket.TextMessage, payload) +} + +func appendWebsocketTimelineDisconnect(builder *strings.Builder, err error, timestamp time.Time) { + if err == nil { + return + } + appendWebsocketTimelineEvent(builder, "disconnect", []byte(err.Error()), timestamp) +} + +func appendWebsocketTimelineEvent(builder *strings.Builder, eventType string, payload []byte, timestamp time.Time) { + if builder == nil { + return + } + trimmedPayload := bytes.TrimSpace(payload) + if len(trimmedPayload) == 0 { + return + } + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString("Timestamp: ") + builder.WriteString(timestamp.Format(time.RFC3339Nano)) + builder.WriteString("\n") + builder.WriteString("Event: websocket.") + builder.WriteString(eventType) + builder.WriteString("\n") + builder.Write(trimmedPayload) + builder.WriteString("\n") } func markAPIResponseTimestamp(c *gin.Context) { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 773df18e..6fce1bf1 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -392,27 +392,45 @@ func TestAppendWebsocketEvent(t *testing.T) { } } -func TestSetWebsocketRequestBody(t *testing.T) { +func TestAppendWebsocketTimelineEvent(t *testing.T) { + var builder strings.Builder + ts := time.Date(2026, time.April, 1, 12, 34, 56, 789000000, time.UTC) + + appendWebsocketTimelineEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"), ts) + + got := builder.String() + if !strings.Contains(got, "Timestamp: 2026-04-01T12:34:56.789Z") { + t.Fatalf("timeline timestamp not found: %s", got) + } + if !strings.Contains(got, "Event: websocket.request") { + t.Fatalf("timeline event not found: %s", got) + } + if !strings.Contains(got, "{\"type\":\"response.create\"}") { + t.Fatalf("timeline payload not found: %s", got) + } +} + +func TestSetWebsocketTimelineBody(t *testing.T) { gin.SetMode(gin.TestMode) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) - setWebsocketRequestBody(c, " \n ") - if _, exists := c.Get(wsRequestBodyKey); exists { - t.Fatalf("request body key should not be set for empty body") + setWebsocketTimelineBody(c, " \n ") + if _, exists := c.Get(wsTimelineBodyKey); exists { + t.Fatalf("timeline body key should not be set for empty body") } - setWebsocketRequestBody(c, "event body") - value, exists := c.Get(wsRequestBodyKey) + setWebsocketTimelineBody(c, "timeline body") + value, exists := c.Get(wsTimelineBodyKey) if !exists { - t.Fatalf("request body key not set") + t.Fatalf("timeline body key not set") } bodyBytes, ok := value.([]byte) if !ok { - t.Fatalf("request body key type mismatch") + t.Fatalf("timeline body key type mismatch") } - if string(bodyBytes) != "event body" { - t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body") + if string(bodyBytes) != "timeline body" { + t.Fatalf("timeline body = %q, want %q", string(bodyBytes), "timeline body") } } @@ -544,14 +562,14 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(data) close(errCh) - var bodyLog strings.Builder + var timelineLog strings.Builder completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, data, errCh, - &bodyLog, + &timelineLog, "session-1", ) if err != nil { @@ -562,6 +580,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { serverErrCh <- errors.New("completed output not captured") return } + if !strings.Contains(timelineLog.String(), "Event: websocket.response") { + serverErrCh <- errors.New("websocket timeline did not capture downstream response") + return + } serverErrCh <- nil })) defer server.Close() @@ -594,6 +616,116 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { } } +func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + serverErrCh := make(chan error, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil) + if err != nil { + serverErrCh <- err + return + } + + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ctx.Request = r + + data := make(chan []byte, 1) + errCh := make(chan *interfaces.ErrorMessage) + data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n") + close(data) + close(errCh) + + var timelineLog strings.Builder + if errClose := conn.Close(); errClose != nil { + serverErrCh <- errClose + return + } + + _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + ctx, + conn, + func(...interface{}) {}, + data, + errCh, + &timelineLog, + "session-1", + ) + if err == nil { + serverErrCh <- errors.New("expected websocket write failure") + return + } + if !strings.Contains(timelineLog.String(), "Event: websocket.response") { + serverErrCh <- errors.New("websocket timeline did not capture attempted downstream response") + return + } + if !strings.Contains(timelineLog.String(), "\"type\":\"response.completed\"") { + serverErrCh <- errors.New("websocket timeline did not retain attempted payload") + return + } + serverErrCh <- nil + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + _ = conn.Close() + }() + + if errServer := <-serverErrCh; errServer != nil { + t.Fatalf("server error: %v", errServer) + } +} + +func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + manager := coreauth.NewManager(nil, nil, nil) + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + + timelineCh := make(chan string, 1) + router := gin.New() + router.GET("/v1/responses/ws", func(c *gin.Context) { + h.ResponsesWebsocket(c) + timeline := "" + if value, exists := c.Get(wsTimelineBodyKey); exists { + if body, ok := value.([]byte); ok { + timeline = string(body) + } + } + timelineCh <- timeline + }) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + + closePayload := websocket.FormatCloseMessage(websocket.CloseGoingAway, "client closing") + if err = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)); err != nil { + t.Fatalf("write close control: %v", err) + } + _ = conn.Close() + + select { + case timeline := <-timelineCh: + if !strings.Contains(timeline, "Event: websocket.disconnect") { + t.Fatalf("websocket timeline missing disconnect event: %s", timeline) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for websocket timeline") + } +} + func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { manager := coreauth.NewManager(nil, nil, nil) auth := &coreauth.Auth{