mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-20 22:51:45 +00:00
Merge PR #479
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -22,6 +23,25 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// writeResponsesSSEChunk writes chunk to w and pads the output so that it
// always ends with a blank line ("\n\n").
//
// Behavior by input:
//   - nil writer or empty chunk: nothing is written.
//   - chunk already ends with "\n\n": written unchanged.
//   - chunk ends with a single "\n": one more "\n" is appended.
//   - anything else: "\n\n" is appended.
//
// The function has no error return. If writing the chunk fails, the terminator
// is not attempted; a failed terminator write is dropped.
func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
	if w == nil || len(chunk) == 0 {
		return
	}
	if _, err := w.Write(chunk); err != nil {
		// Do not emit a dangling terminator after a failed payload write.
		return
	}

	var terminator []byte
	switch {
	case bytes.HasSuffix(chunk, []byte("\n\n")):
		return
	case bytes.HasSuffix(chunk, []byte("\n")):
		terminator = []byte("\n")
	default:
		terminator = []byte("\n\n")
	}
	// Best effort: there is no caller-visible way to report this failure.
	_, _ = w.Write(terminator)
}
|
||||
|
||||
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type OpenAIResponsesAPIHandler struct {
|
||||
@@ -271,11 +291,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
writeResponsesSSEChunk(c.Writer, chunk)
|
||||
flusher.Flush()
|
||||
|
||||
// Continue
|
||||
@@ -400,11 +416,7 @@ func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context,
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
writeResponsesSSEChunk(c.Writer, chunk)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||
}
|
||||
|
||||
data := make(chan []byte, 2)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}")
|
||||
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
body := recorder.Body.String()
|
||||
parts := strings.Split(strings.TrimSpace(body), "\n\n")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), body)
|
||||
}
|
||||
|
||||
expectedPart1 := "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}"
|
||||
if parts[0] != expectedPart1 {
|
||||
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
|
||||
}
|
||||
|
||||
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
|
||||
if parts[1] != expectedPart2 {
|
||||
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
|
||||
}
|
||||
}
|
||||
@@ -33,9 +33,6 @@ const (
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsPayloadLogMaxSize = 2048
|
||||
wsBodyLogMaxSize = 64 * 1024
|
||||
wsBodyLogTruncated = "\n[websocket log truncated]\n"
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
@@ -55,14 +52,14 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
passthroughSessionID := uuid.NewString()
|
||||
clientRemoteAddr := ""
|
||||
if c != nil && c.Request != nil {
|
||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||
}
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
||||
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
|
||||
retainResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||
clientIP := websocketClientAddress(c)
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
|
||||
var wsTerminateErr error
|
||||
var wsBodyLog strings.Builder
|
||||
defer func() {
|
||||
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||
if wsTerminateErr != nil {
|
||||
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||
} else {
|
||||
@@ -167,6 +164,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
|
||||
updatedLastRequest = bytes.Clone(requestJSON)
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
@@ -203,6 +203,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func websocketClientAddress(c *gin.Context) string {
|
||||
if c == nil || c.Request == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(c.ClientIP())
|
||||
}
|
||||
|
||||
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
||||
headers := http.Header{}
|
||||
if req == nil {
|
||||
@@ -277,6 +284,15 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
||||
}
|
||||
}
|
||||
|
||||
// Compaction can cause clients to replace local websocket history with a new
|
||||
// compact transcript on the next `response.create`. When the input already
|
||||
// contains historical model output items, treating it as an incremental append
|
||||
// duplicates stale turn-state and can leave late orphaned function_call items.
|
||||
if shouldReplaceWebsocketTranscript(rawJSON, nextInput) {
|
||||
normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||
if allowIncrementalInputWithPreviousResponseID {
|
||||
@@ -318,6 +334,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
||||
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||
}
|
||||
}
|
||||
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
|
||||
if errDedupeFunctionCalls == nil {
|
||||
mergedInput = dedupedInput
|
||||
}
|
||||
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
@@ -348,6 +368,91 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
|
||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||
if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
|
||||
return false
|
||||
}
|
||||
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, item := range nextInput.Array() {
|
||||
switch strings.TrimSpace(item.Get("type").String()) {
|
||||
case "function_call":
|
||||
return true
|
||||
case "message":
|
||||
role := strings.TrimSpace(item.Get("role").String())
|
||||
if role == "assistant" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return bytes.Clone(normalized)
|
||||
}
|
||||
|
||||
func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
|
||||
rawArray = strings.TrimSpace(rawArray)
|
||||
if rawArray == "" {
|
||||
return "[]", nil
|
||||
}
|
||||
var items []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
seenCallIDs := make(map[string]struct{}, len(items))
|
||||
filtered := make([]json.RawMessage, 0, len(items))
|
||||
for _, item := range items {
|
||||
if len(item) == 0 {
|
||||
continue
|
||||
}
|
||||
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||
if itemType == "function_call" {
|
||||
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||
if callID != "" {
|
||||
if _, ok := seenCallIDs[callID]; ok {
|
||||
continue
|
||||
}
|
||||
seenCallIDs[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
|
||||
out, errMarshal := json.Marshal(filtered)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||
if len(attributes) > 0 {
|
||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||
@@ -613,6 +718,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
) ([]byte, error) {
|
||||
completed := false
|
||||
completedOutput := []byte("[]")
|
||||
downstreamSessionKey := ""
|
||||
if c != nil && c.Request != nil {
|
||||
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -690,6 +799,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
for i := range payloads {
|
||||
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
completed = true
|
||||
@@ -837,71 +947,18 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
if builder.Len() >= wsBodyLogMaxSize {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
if !appendWebsocketLogString(builder, "\n") {
|
||||
return
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
if !appendWebsocketLogString(builder, "websocket.") {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogString(builder, eventType) {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogString(builder, "\n") {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
|
||||
appendWebsocketLogString(builder, wsBodyLogTruncated)
|
||||
return
|
||||
}
|
||||
appendWebsocketLogString(builder, "\n")
|
||||
}
|
||||
|
||||
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
|
||||
if builder == nil {
|
||||
return false
|
||||
}
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if len(value) <= remaining {
|
||||
builder.WriteString(value)
|
||||
return true
|
||||
}
|
||||
builder.WriteString(value[:remaining])
|
||||
return false
|
||||
}
|
||||
|
||||
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
|
||||
if builder == nil {
|
||||
return false
|
||||
}
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if len(value) <= remaining {
|
||||
builder.Write(value)
|
||||
return true
|
||||
}
|
||||
limit := remaining - reserveForSuffix
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
if limit > len(value) {
|
||||
limit = len(value)
|
||||
}
|
||||
builder.Write(value[:limit])
|
||||
return false
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
@@ -917,15 +974,8 @@ func websocketPayloadPreview(payload []byte) string {
|
||||
if len(trimmedPayload) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
preview := trimmedPayload
|
||||
if len(preview) > wsPayloadLogMaxSize {
|
||||
preview = preview[:wsPayloadLogMaxSize]
|
||||
}
|
||||
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||
previewText := strings.ReplaceAll(string(trimmedPayload), "\n", "\\n")
|
||||
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
||||
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
||||
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
||||
}
|
||||
return previewText
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -27,6 +28,12 @@ type websocketCaptureExecutor struct {
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
// websocketCompactionCaptureExecutor is a test executor that records the
// payloads it is asked to execute so tests can inspect them afterwards.
type websocketCompactionCaptureExecutor struct {
	// mu guards streamPayloads and compactPayload.
	mu sync.Mutex
	// streamPayloads holds a copy of each ExecuteStream request payload, in call order.
	streamPayloads [][]byte
	// compactPayload holds a copy of the most recent Execute request payload.
	compactPayload []byte
}
|
||||
|
||||
type orderedWebsocketSelector struct {
|
||||
mu sync.Mutex
|
||||
order []string
|
||||
@@ -126,6 +133,52 @@ func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth,
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// Identifier returns the fixed provider name this test executor registers under.
func (e *websocketCompactionCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCompactionCaptureExecutor) Execute(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
e.mu.Lock()
|
||||
e.compactPayload = bytes.Clone(req.Payload)
|
||||
e.mu.Unlock()
|
||||
if opts.Alt != "responses/compact" {
|
||||
return coreexecutor.Response{}, fmt.Errorf("unexpected non-compact execute alt: %q", opts.Alt)
|
||||
}
|
||||
return coreexecutor.Response{Payload: []byte(`{"id":"cmp-1","object":"response.compaction"}`)}, nil
|
||||
}
|
||||
|
||||
func (e *websocketCompactionCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
callIndex := len(e.streamPayloads)
|
||||
e.streamPayloads = append(e.streamPayloads, bytes.Clone(req.Payload))
|
||||
e.mu.Unlock()
|
||||
|
||||
var payload []byte
|
||||
switch callIndex {
|
||||
case 0:
|
||||
payload = []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}]}}`)
|
||||
case 1:
|
||||
payload = []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[{"type":"message","id":"assistant-1"}]}}`)
|
||||
default:
|
||||
payload = []byte(`{"type":"response.completed","response":{"id":"resp-3","output":[{"type":"message","id":"assistant-2"}]}}`)
|
||||
}
|
||||
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: payload}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketCompactionCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketCompactionCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCompactionCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
@@ -339,33 +392,6 @@ func TestAppendWebsocketEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
|
||||
|
||||
appendWebsocketEvent(&builder, "request", payload)
|
||||
|
||||
got := builder.String()
|
||||
if len(got) > wsBodyLogMaxSize {
|
||||
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
|
||||
}
|
||||
if !strings.Contains(got, wsBodyLogTruncated) {
|
||||
t.Fatalf("expected truncation marker in body log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
|
||||
initial := builder.String()
|
||||
|
||||
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
|
||||
|
||||
if builder.String() != initial {
|
||||
t.Fatalf("builder grew after reaching limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -390,6 +416,108 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) {
|
||||
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`)
|
||||
warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm)
|
||||
if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" {
|
||||
t.Fatalf("expected warmup output to remain")
|
||||
}
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 3 {
|
||||
t.Fatalf("repaired input len = %d, want 3", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("unexpected first item: %s", input[0].Raw)
|
||||
}
|
||||
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("missing inserted output: %s", input[1].Raw)
|
||||
}
|
||||
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) {
|
||||
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("repaired input len = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) {
|
||||
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
callCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
callCache.record(sessionKey, "call-1", []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`))
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 3 {
|
||||
t.Fatalf("repaired input len = %d, want 3", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("missing inserted call: %s", input[0].Raw)
|
||||
}
|
||||
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("unexpected output item: %s", input[1].Raw)
|
||||
}
|
||||
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) {
|
||||
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
callCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("repaired input len = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
|
||||
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`)
|
||||
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
|
||||
|
||||
cached, ok := cache.get(sessionKey, "call-1")
|
||||
if !ok {
|
||||
t.Fatalf("expected cached tool call")
|
||||
}
|
||||
if gjson.GetBytes(cached, "type").String() != "function_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
|
||||
t.Fatalf("unexpected cached tool call: %s", cached)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -593,6 +721,31 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, engine := gin.CreateTestContext(recorder)
|
||||
if err := engine.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}); err != nil {
|
||||
t.Fatalf("SetTrustedProxies: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/ws", nil)
|
||||
req.RemoteAddr = "172.18.0.1:34282"
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.7")
|
||||
c.Request = req
|
||||
|
||||
if got := websocketClientAddress(c); got != strings.TrimSpace(c.ClientIP()) {
|
||||
t.Fatalf("websocketClientAddress = %q, ClientIP = %q", got, c.ClientIP())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketClientAddressReturnsEmptyForNilContext(t *testing.T) {
|
||||
if got := websocketClientAddress(nil); got != "" {
|
||||
t.Fatalf("websocketClientAddress(nil) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -662,3 +815,183 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
|
||||
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"message","id":"assistant-1","role":"assistant"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
|
||||
t.Fatalf("previous_response_id must not exist in transcript replacement mode")
|
||||
}
|
||||
items := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(items) != 2 {
|
||||
t.Fatalf("replacement input len = %d, want 2: %s", len(items), normalized)
|
||||
}
|
||||
if items[0].Get("id").String() != "fc-compact" || items[1].Get("id").String() != "msg-2" {
|
||||
t.Fatalf("replacement transcript was not preserved as-is: %s", normalized)
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match replacement request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplacement(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"message","id":"assistant-1","role":"assistant"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"dev-1","role":"developer"},{"type":"message","id":"msg-2"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
items := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("merged input len = %d, want 4: %s", len(items), normalized)
|
||||
}
|
||||
if items[0].Get("id").String() != "msg-1" ||
|
||||
items[1].Get("id").String() != "assistant-1" ||
|
||||
items[2].Get("id").String() != "dev-1" ||
|
||||
items[3].Get("id").String() != "msg-2" {
|
||||
t.Fatalf("developer follow-up should preserve merge behavior: %s", normalized)
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match merged request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)
|
||||
|
||||
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
|
||||
items := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
|
||||
}
|
||||
if items[0].Get("id").String() != "fc-1" ||
|
||||
items[1].Get("id").String() != "tool-out-1" ||
|
||||
items[2].Get("id").String() != "msg-2" {
|
||||
t.Fatalf("unexpected merged input order: %s", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
executor := &websocketCompactionCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
router.POST("/v1/responses/compact", h.Compact)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
requests := []string{
|
||||
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
|
||||
`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`,
|
||||
}
|
||||
for i := range requests {
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
|
||||
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
|
||||
}
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
|
||||
}
|
||||
}
|
||||
|
||||
compactResp, errPost := server.Client().Post(
|
||||
server.URL+"/v1/responses/compact",
|
||||
"application/json",
|
||||
strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`),
|
||||
)
|
||||
if errPost != nil {
|
||||
t.Fatalf("compact request failed: %v", errPost)
|
||||
}
|
||||
if errClose := compactResp.Body.Close(); errClose != nil {
|
||||
t.Fatalf("close compact response body: %v", errClose)
|
||||
}
|
||||
if compactResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
// Simulate a post-compaction client turn that replaces local history with a compacted transcript.
|
||||
// The websocket handler must treat this as a state reset, not append it to stale pre-compaction state.
|
||||
postCompact := `{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil {
|
||||
t.Fatalf("write post-compact websocket message: %v", errWrite)
|
||||
}
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read post-compact websocket message: %v", errReadMessage)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||
t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted)
|
||||
}
|
||||
|
||||
executor.mu.Lock()
|
||||
defer executor.mu.Unlock()
|
||||
|
||||
if executor.compactPayload == nil {
|
||||
t.Fatalf("compact payload was not captured")
|
||||
}
|
||||
if len(executor.streamPayloads) != 3 {
|
||||
t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads))
|
||||
}
|
||||
|
||||
merged := executor.streamPayloads[2]
|
||||
items := gjson.GetBytes(merged, "input").Array()
|
||||
if len(items) != 2 {
|
||||
t.Fatalf("merged input len = %d, want 2: %s", len(items), merged)
|
||||
}
|
||||
if items[0].Get("id").String() != "fc-compact" ||
|
||||
items[1].Get("id").String() != "msg-2" {
|
||||
t.Fatalf("unexpected post-compact input order: %s", merged)
|
||||
}
|
||||
if items[0].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,402 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Fallback limits substituted by newWebsocketToolOutputCache when the caller
// passes an out-of-range value.
const (
	// websocketToolOutputCacheTTL replaces a negative ttl argument.
	websocketToolOutputCacheTTL = 30 * time.Minute

	// websocketToolOutputCacheMaxPerSession replaces a non-positive
	// maxPerSession argument; it bounds the entries kept per session.
	websocketToolOutputCacheMaxPerSession = 256
)
|
||||
|
||||
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
|
||||
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
|
||||
var defaultWebsocketToolSessionRefs = newWebsocketToolSessionRefCounter()
|
||||
|
||||
type websocketToolOutputCache struct {
|
||||
mu sync.Mutex
|
||||
ttl time.Duration
|
||||
maxPerSession int
|
||||
sessions map[string]*websocketToolOutputSession
|
||||
}
|
||||
|
||||
// websocketToolOutputSession holds the cached items of a single session.
type websocketToolOutputSession struct {
	// lastSeen is refreshed on every record and get for the session.
	lastSeen time.Time
	// outputs maps a call_id to the raw JSON item stored for it.
	outputs map[string]json.RawMessage
	// order lists call_ids in first-insertion order; the front is evicted first.
	order []string
}
|
||||
|
||||
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
|
||||
if ttl < 0 {
|
||||
ttl = websocketToolOutputCacheTTL
|
||||
}
|
||||
if maxPerSession <= 0 {
|
||||
maxPerSession = websocketToolOutputCacheMaxPerSession
|
||||
}
|
||||
return &websocketToolOutputCache{
|
||||
ttl: ttl,
|
||||
maxPerSession: maxPerSession,
|
||||
sessions: make(map[string]*websocketToolOutputSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketToolOutputCache) record(sessionKey string, callID string, item json.RawMessage) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if sessionKey == "" || callID == "" || c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cleanupLocked(now)
|
||||
|
||||
session, ok := c.sessions[sessionKey]
|
||||
if !ok || session == nil {
|
||||
session = &websocketToolOutputSession{
|
||||
lastSeen: now,
|
||||
outputs: make(map[string]json.RawMessage),
|
||||
}
|
||||
c.sessions[sessionKey] = session
|
||||
}
|
||||
session.lastSeen = now
|
||||
|
||||
if _, exists := session.outputs[callID]; !exists {
|
||||
session.order = append(session.order, callID)
|
||||
}
|
||||
session.outputs[callID] = append(json.RawMessage(nil), item...)
|
||||
|
||||
for len(session.order) > c.maxPerSession {
|
||||
evict := session.order[0]
|
||||
session.order = session.order[1:]
|
||||
delete(session.outputs, evict)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketToolOutputCache) get(sessionKey string, callID string) (json.RawMessage, bool) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if sessionKey == "" || callID == "" || c == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cleanupLocked(now)
|
||||
|
||||
session, ok := c.sessions[sessionKey]
|
||||
if !ok || session == nil {
|
||||
return nil, false
|
||||
}
|
||||
session.lastSeen = now
|
||||
item, ok := session.outputs[callID]
|
||||
if !ok || len(item) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return append(json.RawMessage(nil), item...), true
|
||||
}
|
||||
|
||||
func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
|
||||
if c == nil || c.ttl <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for key, session := range c.sessions {
|
||||
if session == nil {
|
||||
delete(c.sessions, key)
|
||||
continue
|
||||
}
|
||||
if now.Sub(session.lastSeen) > c.ttl {
|
||||
delete(c.sessions, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketToolOutputCache) deleteSession(sessionKey string) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" || c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
delete(c.sessions, sessionKey)
|
||||
}
|
||||
|
||||
func websocketDownstreamSessionKey(req *http.Request) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if raw := strings.TrimSpace(req.Header.Get("X-Codex-Turn-Metadata")); raw != "" {
|
||||
if sessionID := strings.TrimSpace(gjson.Get(raw, "session_id").String()); sessionID != "" {
|
||||
return sessionID
|
||||
}
|
||||
}
|
||||
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
|
||||
return sessionID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// websocketToolSessionRefCounter tracks how many holders each session key
// currently has.
type websocketToolSessionRefCounter struct {
	// mu guards counts.
	mu sync.Mutex
	// counts maps a session key to its number of outstanding acquisitions.
	counts map[string]int
}
|
||||
|
||||
func newWebsocketToolSessionRefCounter() *websocketToolSessionRefCounter {
|
||||
return &websocketToolSessionRefCounter{counts: make(map[string]int)}
|
||||
}
|
||||
|
||||
func (c *websocketToolSessionRefCounter) acquire(sessionKey string) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" || c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.counts[sessionKey]++
|
||||
}
|
||||
|
||||
func (c *websocketToolSessionRefCounter) release(sessionKey string) bool {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" || c == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
count := c.counts[sessionKey]
|
||||
if count <= 1 {
|
||||
delete(c.counts, sessionKey)
|
||||
return true
|
||||
}
|
||||
c.counts[sessionKey] = count - 1
|
||||
return false
|
||||
}
|
||||
|
||||
func retainResponsesWebsocketToolCaches(sessionKey string) {
|
||||
if defaultWebsocketToolSessionRefs == nil {
|
||||
return
|
||||
}
|
||||
defaultWebsocketToolSessionRefs.acquire(sessionKey)
|
||||
}
|
||||
|
||||
func releaseResponsesWebsocketToolCaches(sessionKey string) {
|
||||
if defaultWebsocketToolSessionRefs == nil {
|
||||
return
|
||||
}
|
||||
if !defaultWebsocketToolSessionRefs.release(sessionKey) {
|
||||
return
|
||||
}
|
||||
|
||||
if defaultWebsocketToolOutputCache != nil {
|
||||
defaultWebsocketToolOutputCache.deleteSession(sessionKey)
|
||||
}
|
||||
if defaultWebsocketToolCallCache != nil {
|
||||
defaultWebsocketToolCallCache.deleteSession(sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
|
||||
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
|
||||
}
|
||||
|
||||
func repairResponsesWebsocketToolCallsWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
|
||||
return repairResponsesWebsocketToolCallsWithCaches(cache, nil, sessionKey, payload)
|
||||
}
|
||||
|
||||
func repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" || outputCache == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(payload, "input")
|
||||
if !input.Exists() || !input.IsArray() {
|
||||
return payload
|
||||
}
|
||||
|
||||
allowOrphanOutputs := strings.TrimSpace(gjson.GetBytes(payload, "previous_response_id").String()) != ""
|
||||
updatedRaw, errRepair := repairResponsesToolCallsArray(outputCache, callCache, sessionKey, input.Raw, allowOrphanOutputs)
|
||||
if errRepair != nil || updatedRaw == "" || updatedRaw == input.Raw {
|
||||
return payload
|
||||
}
|
||||
|
||||
updated, errSet := sjson.SetRawBytes(payload, "input", []byte(updatedRaw))
|
||||
if errSet != nil {
|
||||
return payload
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCache, sessionKey string, rawArray string, allowOrphanOutputs bool) (string, error) {
|
||||
rawArray = strings.TrimSpace(rawArray)
|
||||
if rawArray == "" {
|
||||
return "[]", nil
|
||||
}
|
||||
|
||||
var items []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
// First pass: record tool outputs and remember which call_ids have outputs in this payload.
|
||||
outputPresent := make(map[string]struct{}, len(items))
|
||||
callPresent := make(map[string]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
if len(item) == 0 {
|
||||
continue
|
||||
}
|
||||
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
outputPresent[callID] = struct{}{}
|
||||
outputCache.record(sessionKey, callID, item)
|
||||
case "function_call":
|
||||
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
callPresent[callID] = struct{}{}
|
||||
if callCache != nil {
|
||||
callCache.record(sessionKey, callID, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filtered := make([]json.RawMessage, 0, len(items))
|
||||
insertedCalls := make(map[string]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
if len(item) == 0 {
|
||||
continue
|
||||
}
|
||||
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||
if itemType == "function_call_output" {
|
||||
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||
if callID == "" {
|
||||
// Upstream rejects tool outputs without a call_id; drop it.
|
||||
continue
|
||||
}
|
||||
|
||||
if allowOrphanOutputs {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := callPresent[callID]; ok {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if callCache != nil {
|
||||
if cached, ok := callCache.get(sessionKey, callID); ok {
|
||||
if _, already := insertedCalls[callID]; !already {
|
||||
filtered = append(filtered, cached)
|
||||
insertedCalls[callID] = struct{}{}
|
||||
callPresent[callID] = struct{}{}
|
||||
}
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
|
||||
continue
|
||||
}
|
||||
if itemType != "function_call" {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
|
||||
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||
if callID == "" {
|
||||
// Upstream rejects tool calls without a call_id; drop it.
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := outputPresent[callID]; ok {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if cached, ok := outputCache.get(sessionKey, callID); ok {
|
||||
filtered = append(filtered, item)
|
||||
filtered = append(filtered, cached)
|
||||
outputPresent[callID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
// Drop orphaned function_call items; upstream rejects transcripts with missing outputs.
|
||||
}
|
||||
|
||||
out, errMarshal := json.Marshal(filtered)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func recordResponsesWebsocketToolCallsFromPayload(sessionKey string, payload []byte) {
|
||||
recordResponsesWebsocketToolCallsFromPayloadWithCache(defaultWebsocketToolCallCache, sessionKey, payload)
|
||||
}
|
||||
|
||||
func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" || cache == nil || len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||
switch eventType {
|
||||
case "response.completed":
|
||||
output := gjson.GetBytes(payload, "response.output")
|
||||
if !output.Exists() || !output.IsArray() {
|
||||
return
|
||||
}
|
||||
for _, item := range output.Array() {
|
||||
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := strings.TrimSpace(item.Get("call_id").String())
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
|
||||
}
|
||||
case "response.output_item.added", "response.output_item.done":
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if !item.Exists() || !item.IsObject() {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
|
||||
return
|
||||
}
|
||||
callID := strings.TrimSpace(item.Get("call_id").String())
|
||||
if callID == "" {
|
||||
return
|
||||
}
|
||||
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user