Merge pull request #2490 from router-for-me/logs

Refactor websocket logging and error handling
This commit is contained in:
Luis Pater
2026-04-02 20:47:31 +08:00
committed by GitHub
8 changed files with 926 additions and 124 deletions

View File

@@ -15,6 +15,8 @@ import (
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse)
}
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
if len(apiWebsocketTimeline) > 0 {
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
}
if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil
return err
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
return data
}
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
if !isExist {
return nil
}
data, ok := apiTimeline.([]byte)
if !ok || len(data) == 0 {
return nil
}
return bytes.Clone(data)
}
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist {
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
}
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.body == nil || w.body.Len() == 0 {
return nil
}
return bytes.Clone(w.body.Bytes())
}
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
}
func extractBodyOverride(c *gin.Context, key string) []byte {
if c == nil {
return nil
}
bodyOverride, isExist := c.Get(key)
if !isExist {
return nil
}
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode,
headers,
body,
websocketTimeline,
apiRequestBody,
apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode,
headers,
body,
websocketTimeline,
apiRequestBody,
apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,

View File

@@ -1,10 +1,14 @@
package middleware
import (
"bytes"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
wrapper.body.WriteString("original-response")
body := wrapper.extractResponseBody(c)
if string(body) != "original-response" {
t.Fatalf("response body = %q, want %q", string(body), "original-response")
}
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
body = wrapper.extractResponseBody(c)
if string(body) != "override-response" {
t.Fatalf("response body = %q, want %q", string(body), "override-response")
}
body[0] = 'X'
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
t.Fatalf("response override should be cloned, got %q", string(got))
}
}
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
body := wrapper.extractResponseBody(c)
if string(body) != "override-response-as-string" {
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
}
}
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
override := []byte("body-override")
c.Set(requestBodyOverrideContextKey, override)
body := extractBodyOverride(c, requestBodyOverrideContextKey)
if !bytes.Equal(body, override) {
t.Fatalf("body override = %q, want %q", string(body), string(override))
}
body[0] = 'X'
if !bytes.Equal(override, []byte("body-override")) {
t.Fatalf("override mutated: %q", string(override))
}
}
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
if got := wrapper.extractWebsocketTimeline(c); got != nil {
t.Fatalf("expected nil websocket timeline, got %q", string(got))
}
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
body := wrapper.extractWebsocketTimeline(c)
if string(body) != "timeline" {
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
}
}
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
streamWriter := &testStreamingLogWriter{}
wrapper := &ResponseWriterWrapper{
ResponseWriter: c.Writer,
logger: &testRequestLogger{enabled: true},
requestInfo: &RequestInfo{
URL: "/v1/responses",
Method: "POST",
Headers: map[string][]string{"Content-Type": {"application/json"}},
RequestID: "req-1",
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
},
isStreaming: true,
streamWriter: streamWriter,
}
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
if err := wrapper.Finalize(c); err != nil {
t.Fatalf("Finalize error: %v", err)
}
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
}
if !streamWriter.closed {
t.Fatal("expected stream writer to be closed")
}
}
type testRequestLogger struct {
enabled bool
}
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
return nil
}
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
return &testStreamingLogWriter{}, nil
}
func (l *testRequestLogger) IsEnabled() bool {
return l.enabled
}
type testStreamingLogWriter struct {
apiWebsocketTimeline []byte
closed bool
}
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
func (w *testStreamingLogWriter) Close() error {
w.closed = true
return nil
}

View File

@@ -172,6 +172,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
true,
"issue-1711",
time.Now(),

View File

@@ -4,6 +4,7 @@
package logging
import (
"bufio"
"bytes"
"compress/flate"
"compress/gzip"
@@ -41,15 +42,17 @@ type RequestLogger interface {
// - statusCode: The response status code
// - responseHeaders: The response headers
// - response: The raw response data
// - websocketTimeline: Optional downstream websocket event timeline
// - apiRequest: The API request data
// - apiResponse: The API response data
// - apiWebsocketTimeline: Optional upstream websocket event timeline
// - requestID: Optional request ID for log file naming
// - requestTimestamp: When the request was received
// - apiResponseTimestamp: When the API response was received
//
// Returns:
// - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
//
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
// This should be called when upstream communication happened over websocket.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
//
// Parameters:
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
//
// Returns:
// - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
}
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
if !l.enabled && !force {
return nil
}
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
requestHeaders,
body,
requestBodyPath,
websocketTimeline,
apiRequest,
apiResponse,
apiWebsocketTimeline,
apiResponseErrors,
statusCode,
responseHeaders,
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
requestHeaders map[string][]string,
requestBody []byte,
requestBodyPath string,
websocketTimeline []byte,
apiRequest []byte,
apiResponse []byte,
apiWebsocketTimeline []byte,
apiResponseErrors []*interfaces.ErrorMessage,
statusCode int,
responseHeaders map[string][]string,
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if requestTimestamp.IsZero() {
requestTimestamp = time.Now()
}
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
return errWrite
}
if isWebsocketTranscript {
// Intentionally omit the generic downstream HTTP response section for websocket
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
// and appending a one-off upgrade response snapshot would dilute that transcript.
return nil
}
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
}
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
body []byte,
bodyPath string,
timestamp time.Time,
downstreamTransport string,
upstreamTransport string,
includeBody bool,
) error {
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
return errWrite
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
return errWrite
}
if strings.TrimSpace(downstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
return errWrite
}
}
if strings.TrimSpace(upstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite
}
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
}
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite
}
if !includeBody {
return nil
}
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
return errWrite
}
bodyTrailingNewlines := 1
if bodyPath != "" {
bodyFile, errOpen := os.Open(bodyPath)
if errOpen != nil {
return errOpen
}
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
tracker := &trailingNewlineTrackingWriter{writer: w}
written, errCopy := io.Copy(tracker, bodyFile)
if errCopy != nil {
_ = bodyFile.Close()
return errCopy
}
if written > 0 {
bodyTrailingNewlines = tracker.trailingNewlines
}
if errClose := bodyFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request body temp file")
}
} else if _, errWrite := w.Write(body); errWrite != nil {
return errWrite
} else if len(body) > 0 {
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
}
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
return errWrite
}
return nil
}
func countTrailingNewlinesBytes(payload []byte) int {
count := 0
for i := len(payload) - 1; i >= 0; i-- {
if payload[i] != '\n' {
break
}
count++
}
return count
}
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
missingNewlines := 3 - trailingNewlines
if missingNewlines <= 0 {
return nil
}
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
return errWrite
}
type trailingNewlineTrackingWriter struct {
writer io.Writer
trailingNewlines int
}
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
written, errWrite := t.writer.Write(payload)
if written > 0 {
writtenPayload := payload[:written]
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
if trailingNewlines == len(writtenPayload) {
t.trailingNewlines += trailingNewlines
} else {
t.trailingNewlines = trailingNewlines
}
}
return written, errWrite
}
func hasSectionPayload(payload []byte) bool {
return len(bytes.TrimSpace(payload)) > 0
}
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
if hasSectionPayload(websocketTimeline) {
return "websocket"
}
for key, values := range headers {
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
for _, value := range values {
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
return "websocket"
}
}
}
}
return "http"
}
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
hasWS := hasSectionPayload(apiWebsocketTimeline)
switch {
case hasHTTP && hasWS:
return "websocket+http"
case hasWS:
return "websocket"
case hasHTTP:
return "http"
default:
return ""
}
}
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
if len(payload) == 0 {
return nil
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
if !bytes.HasSuffix(payload, []byte("\n")) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
} else {
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
return errWrite
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
return errWrite
}
return nil
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
return errWrite
}
trailingNewlines := 1
if apiResponseErrors[i].Error != nil {
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
errText := apiResponseErrors[i].Error.Error()
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
return errWrite
}
if errText != "" {
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
}
}
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
return errWrite
}
}
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
var bufferedReader *bufio.Reader
if responseReader != nil {
bufferedReader = bufio.NewReader(responseReader)
}
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
if responseReader != nil {
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
if bufferedReader != nil {
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
return errCopy
}
}
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
return nil
}
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
if reader == nil {
return false
}
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
return true
}
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
return true
}
return false
}
// formatLogContent creates the complete log content for non-streaming requests.
//
// Parameters:
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - websocketTimeline: The downstream websocket event timeline
// - apiRequest: The API request data
// - apiResponse: The API response data
// - response: The raw response data
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
//
// Returns:
// - string: The formatted log content
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
var content strings.Builder
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
// Request info
content.WriteString(l.formatRequestInfo(url, method, headers, body))
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
if len(websocketTimeline) > 0 {
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
content.Write(websocketTimeline)
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
content.Write(websocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiWebsocketTimeline) > 0 {
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
content.Write(apiWebsocketTimeline)
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
content.Write(apiWebsocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiRequest) > 0 {
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
content.WriteString("\n")
}
if isWebsocketTranscript {
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
// timeline sections instead of a generic downstream HTTP response block.
return content.String()
}
// Response section
content.WriteString("=== RESPONSE ===\n")
content.WriteString(fmt.Sprintf("Status: %d\n", status))
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
//
// Returns:
// - string: The formatted request information
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
var content strings.Builder
content.WriteString("=== REQUEST INFO ===\n")
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
content.WriteString(fmt.Sprintf("URL: %s\n", url))
content.WriteString(fmt.Sprintf("Method: %s\n", method))
if strings.TrimSpace(downstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
}
if strings.TrimSpace(upstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
}
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
content.WriteString("\n")
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
}
content.WriteString("\n")
if !includeBody {
return content.String()
}
content.WriteString("=== REQUEST BODY ===\n")
content.Write(body)
content.WriteString("\n\n")
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
// apiResponse stores the upstream API response data.
apiResponse []byte
// apiWebsocketTimeline stores the upstream websocket event timeline.
apiWebsocketTimeline []byte
// apiResponseTimestamp captures when the API response was received.
apiResponseTimestamp time.Time
}
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
return nil
}
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
if len(apiWebsocketTimeline) == 0 {
return nil
}
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
if !timestamp.IsZero() {
w.apiResponseTimestamp = timestamp
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
// Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
//
// Returns:
// - error: An error if closing fails, nil otherwise
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
}
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil
}
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
return nil
}
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
// Close is a no-op implementation that does nothing and always returns nil.

View File

@@ -219,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
wsReqBody := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
wsReqLog := helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -229,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
}
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if respHS != nil {
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr)
if respHS != nil {
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.Execute(ctx, auth, req, opts)
@@ -246,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if respHS != nil && respHS.StatusCode > 0 {
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
helps.RecordAPIResponseError(ctx, e.cfg, errDial)
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
return resp, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
defer func() {
@@ -278,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
// Retry once with a fresh websocket connection. This is mainly to handle
// upstream closing the socket between sequential requests within the same
// execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil {
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -292,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthType: authType,
AuthValue: authValue,
})
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
conn = connRetry
wsReqBody = wsReqBodyRetry
} else {
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
return resp, errSendRetry
}
} else {
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry)
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
return resp, errDialRetry
}
} else {
helps.RecordAPIResponseError(ctx, e.cfg, errSend)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
return resp, errSend
}
}
@@ -316,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
return resp, errRead
}
if msgType != websocket.TextMessage {
@@ -325,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
helps.RecordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
return resp, err
}
continue
@@ -335,13 +335,13 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if len(payload) == 0 {
continue
}
helps.AppendAPIResponseChunk(ctx, e.cfg, payload)
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
helps.RecordAPIResponseError(ctx, e.cfg, wsErr)
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
return resp, wsErr
}
@@ -413,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
wsReqBody := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
wsReqLog := helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -423,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
}
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header
if respHS != nil {
upstreamHeaders = respHS.Header.Clone()
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr)
if respHS != nil {
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
@@ -442,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if respHS != nil && respHS.StatusCode > 0 {
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
helps.RecordAPIResponseError(ctx, e.cfg, errDial)
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
if sess != nil {
sess.reqMu.Unlock()
}
return nil, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
@@ -461,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errSend)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a new websocket connection for the same execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry != nil || connRetry == nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry)
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
return nil, errDialRetry
}
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -485,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthType: authType,
AuthValue: authValue,
})
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
@@ -552,7 +554,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
terminateReason = "read_error"
terminateErr = errRead
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
reporter.PublishFailure(ctx)
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
return
@@ -562,7 +564,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
err = fmt.Errorf("codex websockets executor: unexpected binary message")
terminateReason = "unexpected_binary"
terminateErr = err
helps.RecordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
reporter.PublishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
@@ -577,12 +579,12 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if len(payload) == 0 {
continue
}
helps.AppendAPIResponseChunk(ctx, e.cfg, payload)
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
terminateReason = "upstream_error"
terminateErr = wsErr
helps.RecordAPIResponseError(ctx, e.cfg, wsErr)
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
reporter.PublishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
@@ -1022,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
return line
}
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
upgradeInfo := info
upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL)
upgradeInfo.Method = http.MethodGet
upgradeInfo.Body = nil
upgradeInfo.Headers = info.Headers.Clone()
if upgradeInfo.Headers == nil {
upgradeInfo.Headers = make(http.Header)
}
if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" {
upgradeInfo.Headers.Set("Connection", "Upgrade")
}
if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" {
upgradeInfo.Headers.Set("Upgrade", "websocket")
}
return upgradeInfo
}
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
if resp == nil {
return
}
helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
}
func websocketHandshakeBody(resp *http.Response) []byte {
if resp == nil || resp.Body == nil {
return nil

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"html"
"net/http"
"net/url"
"sort"
"strings"
"time"
@@ -19,9 +20,10 @@ import (
)
const (
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
apiRequestKey = "API_REQUEST"
apiResponseKey = "API_RESPONSE"
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
apiRequestKey = "API_REQUEST"
apiResponseKey = "API_RESPONSE"
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
)
// UpstreamRequestLog captures the outbound upstream request details for logging.
@@ -46,6 +48,7 @@ type upstreamAttempt struct {
headersWritten bool
bodyStarted bool
bodyHasContent bool
prevWasSSEEvent bool
errorWritten bool
}
@@ -173,15 +176,157 @@ func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
attempt.response.WriteString("Body:\n")
attempt.bodyStarted = true
}
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
if attempt.bodyHasContent {
attempt.response.WriteString("\n\n")
separator := "\n\n"
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
separator = "\n"
}
attempt.response.WriteString(separator)
}
attempt.response.WriteString(string(data))
attempt.bodyHasContent = true
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
updateAggregatedResponse(ginCtx, attempts)
}
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.request\n")
if info.URL != "" {
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
}
if auth := formatAuthInfo(info); auth != "" {
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, info.Headers)
builder.WriteString("\nBody:\n")
if len(info.Body) > 0 {
builder.Write(info.Body)
} else {
builder.WriteString("<empty>")
}
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.handshake\n")
if status > 0 {
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, headers)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
RecordAPIRequest(ctx, cfg, info)
RecordAPIResponseMetadata(ctx, cfg, status, headers)
AppendAPIResponseChunk(ctx, cfg, body)
}
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
func WebsocketUpgradeRequestURL(rawURL string) string {
trimmedURL := strings.TrimSpace(rawURL)
if trimmedURL == "" {
return ""
}
parsed, err := url.Parse(trimmedURL)
if err != nil {
return trimmedURL
}
switch strings.ToLower(parsed.Scheme) {
case "ws":
parsed.Scheme = "http"
case "wss":
parsed.Scheme = "https"
}
return parsed.String()
}
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
data := bytes.TrimSpace(payload)
if len(data) == 0 {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.response\n")
builder.Write(data)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
if cfg == nil || !cfg.RequestLog || err == nil {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.error\n")
if trimmed := strings.TrimSpace(stage); trimmed != "" {
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
}
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
func ginContextFrom(ctx context.Context) *gin.Context {
ginCtx, _ := ctx.Value("gin").(*gin.Context)
return ginCtx
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
ginCtx.Set(apiResponseKey, []byte(builder.String()))
}
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
if ginCtx == nil {
return
}
data := bytes.TrimSpace(chunk)
if len(data) == 0 {
return
}
if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
combined := make([]byte, 0, len(existingBytes)+len(data)+2)
combined = append(combined, existingBytes...)
if !bytes.HasSuffix(existingBytes, []byte("\n")) {
combined = append(combined, '\n')
}
combined = append(combined, '\n')
combined = append(combined, data...)
ginCtx.Set(apiWebsocketTimelineKey, combined)
return
}
}
ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data))
}
func markAPIResponseTimestamp(ginCtx *gin.Context) {
if ginCtx == nil {
return
}
if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists {
return
}
ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
}
func writeHeaders(builder *strings.Builder, headers http.Header) {
if builder == nil {
return

View File

@@ -32,7 +32,7 @@ const (
wsEventTypeCompleted = "response.completed"
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsTimelineBodyKey = "WEBSOCKET_TIMELINE_OVERRIDE"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -57,10 +57,11 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
clientIP := websocketClientAddress(c)
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
var wsTerminateErr error
var wsBodyLog strings.Builder
var wsTimelineLog strings.Builder
defer func() {
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
if wsTerminateErr != nil {
appendWebsocketTimelineDisconnect(&wsTimelineLog, wsTerminateErr, time.Now())
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
@@ -69,7 +70,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
}
setWebsocketRequestBody(c, wsBodyLog.String())
setWebsocketTimelineBody(c, wsTimelineLog.String())
if errClose := conn.Close(); errClose != nil {
log.Warnf("responses websocket: close connection error: %v", errClose)
}
@@ -83,7 +84,6 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
msgType, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
wsTerminateErr = errReadMessage
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
} else {
@@ -101,7 +101,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// websocketPayloadEventType(payload),
// websocketPayloadPreview(payload),
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now())
allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
@@ -128,8 +128,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
errorPayload, errWrite := writeResponsesWebsocketError(conn, &wsTimelineLog, errMsg)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
passthroughSessionID,
@@ -157,9 +156,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
lastRequest = updatedLastRequest
lastResponseOutput = []byte("[]")
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsTimelineLog, passthroughSessionID); errWrite != nil {
wsTerminateErr = errWrite
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
return
}
continue
@@ -192,10 +190,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
@@ -597,7 +594,7 @@ func writeResponsesWebsocketSyntheticPrewarm(
c *gin.Context,
conn *websocket.Conn,
requestJSON []byte,
wsBodyLog *strings.Builder,
wsTimelineLog *strings.Builder,
sessionID string,
) error {
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
@@ -606,7 +603,6 @@ func writeResponsesWebsocketSyntheticPrewarm(
}
for i := 0; i < len(payloads); i++ {
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
@@ -614,7 +610,7 @@ func writeResponsesWebsocketSyntheticPrewarm(
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
@@ -713,7 +709,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
cancel handlers.APIHandlerCancelFunc,
data <-chan []byte,
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
wsTimelineLog *strings.Builder,
sessionID string,
) ([]byte, error) {
completed := false
@@ -736,8 +732,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
@@ -771,8 +766,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
}
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
@@ -806,7 +800,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
completedOutput = responseCompletedOutputFromPayload(payloads[i])
}
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
@@ -814,7 +807,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
@@ -870,7 +863,7 @@ func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
return payloads
}
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
func writeResponsesWebsocketError(conn *websocket.Conn, wsTimelineLog *strings.Builder, errMsg *interfaces.ErrorMessage) ([]byte, error) {
status := http.StatusInternalServerError
errText := http.StatusText(status)
if errMsg != nil {
@@ -940,7 +933,7 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
}
}
return payload, conn.WriteMessage(websocket.TextMessage, payload)
return payload, writeResponsesWebsocketPayload(conn, wsTimelineLog, payload, time.Now())
}
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
@@ -979,7 +972,11 @@ func websocketPayloadPreview(payload []byte) string {
return previewText
}
func setWebsocketRequestBody(c *gin.Context, body string) {
func setWebsocketTimelineBody(c *gin.Context, body string) {
setWebsocketBody(c, wsTimelineBodyKey, body)
}
func setWebsocketBody(c *gin.Context, key string, body string) {
if c == nil {
return
}
@@ -987,7 +984,40 @@ func setWebsocketRequestBody(c *gin.Context, body string) {
if trimmedBody == "" {
return
}
c.Set(wsRequestBodyKey, []byte(trimmedBody))
c.Set(key, []byte(trimmedBody))
}
func writeResponsesWebsocketPayload(conn *websocket.Conn, wsTimelineLog *strings.Builder, payload []byte, timestamp time.Time) error {
appendWebsocketTimelineEvent(wsTimelineLog, "response", payload, timestamp)
return conn.WriteMessage(websocket.TextMessage, payload)
}
func appendWebsocketTimelineDisconnect(builder *strings.Builder, err error, timestamp time.Time) {
if err == nil {
return
}
appendWebsocketTimelineEvent(builder, "disconnect", []byte(err.Error()), timestamp)
}
func appendWebsocketTimelineEvent(builder *strings.Builder, eventType string, payload []byte, timestamp time.Time) {
if builder == nil {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString("Timestamp: ")
builder.WriteString(timestamp.Format(time.RFC3339Nano))
builder.WriteString("\n")
builder.WriteString("Event: websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
}
func markAPIResponseTimestamp(c *gin.Context) {

View File

@@ -392,27 +392,45 @@ func TestAppendWebsocketEvent(t *testing.T) {
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
func TestAppendWebsocketTimelineEvent(t *testing.T) {
var builder strings.Builder
ts := time.Date(2026, time.April, 1, 12, 34, 56, 789000000, time.UTC)
appendWebsocketTimelineEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"), ts)
got := builder.String()
if !strings.Contains(got, "Timestamp: 2026-04-01T12:34:56.789Z") {
t.Fatalf("timeline timestamp not found: %s", got)
}
if !strings.Contains(got, "Event: websocket.request") {
t.Fatalf("timeline event not found: %s", got)
}
if !strings.Contains(got, "{\"type\":\"response.create\"}") {
t.Fatalf("timeline payload not found: %s", got)
}
}
func TestSetWebsocketTimelineBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
setWebsocketRequestBody(c, " \n ")
if _, exists := c.Get(wsRequestBodyKey); exists {
t.Fatalf("request body key should not be set for empty body")
setWebsocketTimelineBody(c, " \n ")
if _, exists := c.Get(wsTimelineBodyKey); exists {
t.Fatalf("timeline body key should not be set for empty body")
}
setWebsocketRequestBody(c, "event body")
value, exists := c.Get(wsRequestBodyKey)
setWebsocketTimelineBody(c, "timeline body")
value, exists := c.Get(wsTimelineBodyKey)
if !exists {
t.Fatalf("request body key not set")
t.Fatalf("timeline body key not set")
}
bodyBytes, ok := value.([]byte)
if !ok {
t.Fatalf("request body key type mismatch")
t.Fatalf("timeline body key type mismatch")
}
if string(bodyBytes) != "event body" {
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
if string(bodyBytes) != "timeline body" {
t.Fatalf("timeline body = %q, want %q", string(bodyBytes), "timeline body")
}
}
@@ -544,14 +562,14 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
close(data)
close(errCh)
var bodyLog strings.Builder
var timelineLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
data,
errCh,
&bodyLog,
&timelineLog,
"session-1",
)
if err != nil {
@@ -562,6 +580,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
serverErrCh <- errors.New("completed output not captured")
return
}
if !strings.Contains(timelineLog.String(), "Event: websocket.response") {
serverErrCh <- errors.New("websocket timeline did not capture downstream response")
return
}
serverErrCh <- nil
}))
defer server.Close()
@@ -594,6 +616,116 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
}
}
func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing.T) {
gin.SetMode(gin.TestMode)
serverErrCh := make(chan error, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
if err != nil {
serverErrCh <- err
return
}
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Request = r
data := make(chan []byte, 1)
errCh := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
close(data)
close(errCh)
var timelineLog strings.Builder
if errClose := conn.Close(); errClose != nil {
serverErrCh <- errClose
return
}
_, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
data,
errCh,
&timelineLog,
"session-1",
)
if err == nil {
serverErrCh <- errors.New("expected websocket write failure")
return
}
if !strings.Contains(timelineLog.String(), "Event: websocket.response") {
serverErrCh <- errors.New("websocket timeline did not capture attempted downstream response")
return
}
if !strings.Contains(timelineLog.String(), "\"type\":\"response.completed\"") {
serverErrCh <- errors.New("websocket timeline did not retain attempted payload")
return
}
serverErrCh <- nil
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
_ = conn.Close()
}()
if errServer := <-serverErrCh; errServer != nil {
t.Fatalf("server error: %v", errServer)
}
}
func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
manager := coreauth.NewManager(nil, nil, nil)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
timelineCh := make(chan string, 1)
router := gin.New()
router.GET("/v1/responses/ws", func(c *gin.Context) {
h.ResponsesWebsocket(c)
timeline := ""
if value, exists := c.Get(wsTimelineBodyKey); exists {
if body, ok := value.([]byte); ok {
timeline = string(body)
}
}
timelineCh <- timeline
})
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
closePayload := websocket.FormatCloseMessage(websocket.CloseGoingAway, "client closing")
if err = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)); err != nil {
t.Fatalf("write close control: %v", err)
}
_ = conn.Close()
select {
case timeline := <-timelineCh:
if !strings.Contains(timeline, "Event: websocket.disconnect") {
t.Fatalf("websocket timeline missing disconnect event: %s", timeline)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for websocket timeline")
}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{