mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-08 06:43:41 +00:00
Fixed: #1901
test(websocket): add tests for incremental input and prewarm handling logic - Added test cases for incremental input support based on upstream capabilities. - Introduced validation for prewarm handling of `response.create` messages locally. - Enhanced test coverage for websocket executor behavior, including payload forwarding checks. - Updated websocket implementation with prewarm and incremental input logic for better testability.
This commit is contained in:
@@ -14,7 +14,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -100,11 +104,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||
allowIncrementalInputWithPreviousResponseID := false
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
} else {
|
||||
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
}
|
||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
@@ -139,6 +149,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||
requestJSON = updated
|
||||
}
|
||||
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
|
||||
updatedLastRequest = updated
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
lastResponseOutput = []byte("[]")
|
||||
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
||||
wsTerminateErr = errWrite
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
@@ -339,6 +365,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||
if h == nil || h.AuthManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
resolvedModelName := modelName
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||
if initialSuffix.HasSuffix {
|
||||
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||
} else {
|
||||
resolvedModelName = resolvedBase
|
||||
}
|
||||
} else {
|
||||
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||
}
|
||||
|
||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||
providers := util.GetProviderName(baseModel)
|
||||
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||
providers = util.GetProviderName(resolvedModelName)
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for i := 0; i < len(providers); i++ {
|
||||
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[providerKey] = struct{}{}
|
||||
}
|
||||
if len(providerSet) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
modelKey := baseModel
|
||||
if modelKey == "" {
|
||||
modelKey = strings.TrimSpace(resolvedModelName)
|
||||
}
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
now := time.Now()
|
||||
auths := h.AuthManager.List()
|
||||
for i := 0; i < len(auths); i++ {
|
||||
auth := auths[i]
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
||||
continue
|
||||
}
|
||||
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
||||
continue
|
||||
}
|
||||
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
|
||||
return false
|
||||
}
|
||||
if modelName != "" && len(auth.ModelStates) > 0 {
|
||||
state, ok := auth.ModelStates[modelName]
|
||||
if (!ok || state == nil) && modelName != "" {
|
||||
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
|
||||
if baseModel != "" && baseModel != modelName {
|
||||
state, ok = auth.ModelStates[baseModel]
|
||||
}
|
||||
}
|
||||
if ok && state != nil {
|
||||
if state.Status == coreauth.StatusDisabled {
|
||||
return false
|
||||
}
|
||||
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
|
||||
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||
return false
|
||||
}
|
||||
generateResult := gjson.GetBytes(rawJSON, "generate")
|
||||
return generateResult.Exists() && !generateResult.Bool()
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketSyntheticPrewarm(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
requestJSON []byte,
|
||||
wsBodyLog *strings.Builder,
|
||||
sessionID string,
|
||||
) error {
|
||||
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
||||
if errPayloads != nil {
|
||||
return errPayloads
|
||||
}
|
||||
for i := 0; i < len(payloads); i++ {
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
// websocket.TextMessage,
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(payloads[i]),
|
||||
errWrite,
|
||||
)
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
|
||||
responseID := "resp_prewarm_" + uuid.NewString()
|
||||
createdAt := time.Now().Unix()
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
|
||||
|
||||
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||
var errSet error
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
if modelName != "" {
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
if modelName != "" {
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
return [][]byte{createdPayload, completedPayload}, nil
|
||||
}
|
||||
|
||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||
existingRaw = strings.TrimSpace(existingRaw)
|
||||
appendRaw = strings.TrimSpace(appendRaw)
|
||||
@@ -550,47 +762,63 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
|
||||
}
|
||||
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
payload := map[string]any{
|
||||
"type": wsEventTypeError,
|
||||
"status": status,
|
||||
payload := []byte(`{}`)
|
||||
var errSet error
|
||||
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
payload, errSet = sjson.SetBytes(payload, "status", status)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
|
||||
if errMsg != nil && errMsg.Addon != nil {
|
||||
headers := map[string]any{}
|
||||
headers := []byte(`{}`)
|
||||
hasHeaders := false
|
||||
for key, values := range errMsg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
headers[key] = values[0]
|
||||
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
|
||||
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
hasHeaders = true
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
payload["headers"] = headers
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
var decoded map[string]any
|
||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||
if inner, ok := decoded["error"]; ok {
|
||||
payload["error"] = inner
|
||||
} else {
|
||||
payload["error"] = decoded
|
||||
if hasHeaders {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; !ok {
|
||||
payload["error"] = map[string]any{
|
||||
"type": "server_error",
|
||||
"message": errText,
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
errorNode := gjson.GetBytes(body, "error")
|
||||
if errorNode.Exists() {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
|
||||
} else {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "error", body)
|
||||
}
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !gjson.GetBytes(payload, "error").Exists() {
|
||||
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
|
||||
@@ -2,7 +2,9 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -11,9 +13,46 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type websocketCaptureExecutor struct {
|
||||
streamCalls int
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.streamCalls++
|
||||
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
@@ -326,3 +365,130 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
t.Fatalf("server error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: "test-provider",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
|
||||
t.Fatalf("expected websocket-capable upstream for test-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
executor := &websocketCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
|
||||
if errWrite != nil {
|
||||
t.Fatalf("write prewarm websocket message: %v", errWrite)
|
||||
}
|
||||
|
||||
_, createdPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read prewarm created message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
|
||||
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
|
||||
}
|
||||
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
|
||||
if prewarmResponseID == "" {
|
||||
t.Fatalf("prewarm response id is empty")
|
||||
}
|
||||
if executor.streamCalls != 0 {
|
||||
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
|
||||
}
|
||||
|
||||
_, completedPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read prewarm completed message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
|
||||
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
|
||||
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
|
||||
}
|
||||
|
||||
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
|
||||
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
|
||||
if errWrite != nil {
|
||||
t.Fatalf("write follow-up websocket message: %v", errWrite)
|
||||
}
|
||||
|
||||
_, upstreamPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read upstream completed message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if executor.streamCalls != 1 {
|
||||
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
|
||||
}
|
||||
if len(executor.payloads) != 1 {
|
||||
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
|
||||
}
|
||||
forwarded := executor.payloads[0]
|
||||
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
|
||||
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
|
||||
}
|
||||
if gjson.GetBytes(forwarded, "generate").Exists() {
|
||||
t.Fatalf("generate leaked upstream: %s", forwarded)
|
||||
}
|
||||
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
|
||||
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
|
||||
}
|
||||
input := gjson.GetBytes(forwarded, "input").Array()
|
||||
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user