mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-03 19:21:17 +00:00
Merge upstream v6.9.9 (PR #483)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -37,13 +37,13 @@ GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.worktrees/
|
||||
.codex/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
|
||||
@@ -83,6 +83,10 @@ mode without windows or traces, and enables cross-device AI Q&A interaction and
|
||||
LAN). Essentially, it is an automated collaboration layer of "screen/audio capture + AI inference + low-friction delivery",
|
||||
helping users to immersively use AI assistants across applications on controlled devices or in restricted environments.
|
||||
|
||||
### [ProxyPal](https://github.com/buddingnewinsights/proxypal)
|
||||
|
||||
Cross-platform desktop app (macOS, Windows, Linux) wrapping CLIProxyAPI with a native GUI. Connects Claude, ChatGPT, Gemini, GitHub Copilot, Qwen, iFlow, and custom OpenAI-compatible endpoints with usage analytics, request monitoring, and auto-configuration for popular coding tools - no API keys needed.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -78,6 +78,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
|
||||
|
||||
Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口、无痕迹的隐蔽运行方式,并通过局域网实现跨设备的 AI 问答交互与控制。本质上是一个「屏幕/音频采集 + AI 推理 + 低摩擦投送」的自动化协作层,帮助用户在受控设备/受限环境下沉浸式跨应用地使用 AI 助手。
|
||||
|
||||
### [ProxyPal](https://github.com/buddingnewinsights/proxypal)
|
||||
|
||||
跨平台桌面应用(macOS、Windows、Linux),以原生 GUI 封装 CLIProxyAPI。支持连接 Claude、ChatGPT、Gemini、GitHub Copilot、Qwen、iFlow 及自定义 OpenAI 兼容端点,具备使用分析、请求监控和热门编程工具自动配置功能,无需 API 密钥。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -170,6 +170,10 @@ New API互換リレーサイトアカウントをワンストップで管理す
|
||||
|
||||
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LAN(ローカルエリアネットワーク)を介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
|
||||
|
||||
### [ProxyPal](https://github.com/buddingnewinsights/proxypal)
|
||||
|
||||
CLIProxyAPIをネイティブGUIでラップしたクロスプラットフォームデスクトップアプリ(macOS、Windows、Linux)。Claude、ChatGPT、Gemini、GitHub Copilot、Qwen、iFlow、カスタムOpenAI互換エンドポイントに対応し、使用状況分析、リクエスト監視、人気コーディングツールの自動設定機能を搭載 - APIキー不要
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
|
||||
@@ -1047,6 +1047,7 @@ func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Aut
|
||||
auth.Runtime = existing.Runtime
|
||||
}
|
||||
}
|
||||
coreauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
@@ -1129,7 +1130,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
@@ -1137,11 +1138,12 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
@@ -1177,13 +1179,107 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
|
||||
changed := false
|
||||
if req.Prefix != nil {
|
||||
targetAuth.Prefix = *req.Prefix
|
||||
prefix := strings.TrimSpace(*req.Prefix)
|
||||
targetAuth.Prefix = prefix
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if prefix == "" {
|
||||
delete(targetAuth.Metadata, "prefix")
|
||||
} else {
|
||||
targetAuth.Metadata["prefix"] = prefix
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if req.ProxyURL != nil {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
proxyURL := strings.TrimSpace(*req.ProxyURL)
|
||||
targetAuth.ProxyURL = proxyURL
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if proxyURL == "" {
|
||||
delete(targetAuth.Metadata, "proxy_url")
|
||||
} else {
|
||||
targetAuth.Metadata["proxy_url"] = proxyURL
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if len(req.Headers) > 0 {
|
||||
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata)
|
||||
nextHeaders := make(map[string]string, len(existingHeaders))
|
||||
for k, v := range existingHeaders {
|
||||
nextHeaders[k] = v
|
||||
}
|
||||
headerChanged := false
|
||||
|
||||
for key, value := range req.Headers {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(value)
|
||||
attrKey := "header:" + name
|
||||
if val == "" {
|
||||
if _, ok := nextHeaders[name]; ok {
|
||||
delete(nextHeaders, name)
|
||||
headerChanged = true
|
||||
}
|
||||
if targetAuth.Attributes != nil {
|
||||
if _, ok := targetAuth.Attributes[attrKey]; ok {
|
||||
headerChanged = true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if prev, ok := nextHeaders[name]; !ok || prev != val {
|
||||
headerChanged = true
|
||||
}
|
||||
nextHeaders[name] = val
|
||||
if targetAuth.Attributes != nil {
|
||||
if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val {
|
||||
headerChanged = true
|
||||
}
|
||||
} else {
|
||||
headerChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if headerChanged {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if targetAuth.Attributes == nil {
|
||||
targetAuth.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
for key, value := range req.Headers {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(value)
|
||||
attrKey := "header:" + name
|
||||
if val == "" {
|
||||
delete(nextHeaders, name)
|
||||
delete(targetAuth.Attributes, attrKey)
|
||||
continue
|
||||
}
|
||||
nextHeaders[name] = val
|
||||
targetAuth.Attributes[attrKey] = val
|
||||
}
|
||||
|
||||
if len(nextHeaders) == 0 {
|
||||
delete(targetAuth.Metadata, "headers")
|
||||
} else {
|
||||
metaHeaders := make(map[string]any, len(nextHeaders))
|
||||
for k, v := range nextHeaders {
|
||||
metaHeaders[k] = v
|
||||
}
|
||||
targetAuth.Metadata["headers"] = metaHeaders
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if req.Priority != nil || req.Note != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
|
||||
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "test.json",
|
||||
FileName: "test.json",
|
||||
Provider: "claude",
|
||||
Attributes: map[string]string{
|
||||
"path": "/tmp/test.json",
|
||||
"header:X-Old": "old",
|
||||
"header:X-Remove": "gone",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "claude",
|
||||
"headers": map[string]any{
|
||||
"X-Old": "old",
|
||||
"X-Remove": "gone",
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||
|
||||
body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx.Request = req
|
||||
h.PatchAuthFileFields(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID("test.json")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth record to exist after patch")
|
||||
}
|
||||
|
||||
if updated.Prefix != "p1" {
|
||||
t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
|
||||
}
|
||||
if updated.ProxyURL != "http://proxy.local" {
|
||||
t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
|
||||
}
|
||||
|
||||
if updated.Metadata == nil {
|
||||
t.Fatalf("expected metadata to be non-nil")
|
||||
}
|
||||
if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
|
||||
t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
|
||||
}
|
||||
if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
|
||||
t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
|
||||
}
|
||||
|
||||
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||
if !ok {
|
||||
raw, _ := json.Marshal(updated.Metadata["headers"])
|
||||
t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
|
||||
}
|
||||
if got := headersMeta["X-Old"]; got != "new" {
|
||||
t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
|
||||
}
|
||||
if got := headersMeta["X-New"]; got != "v" {
|
||||
t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
|
||||
}
|
||||
if _, ok := headersMeta["X-Remove"]; ok {
|
||||
t.Fatalf("expected metadata.headers.X-Remove to be deleted")
|
||||
}
|
||||
if _, ok := headersMeta["X-Nope"]; ok {
|
||||
t.Fatalf("expected metadata.headers.X-Nope to be absent")
|
||||
}
|
||||
|
||||
if got := updated.Attributes["header:X-Old"]; got != "new" {
|
||||
t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
|
||||
}
|
||||
if got := updated.Attributes["header:X-New"]; got != "v" {
|
||||
t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
|
||||
}
|
||||
if _, ok := updated.Attributes["header:X-Remove"]; ok {
|
||||
t.Fatalf("expected attrs header:X-Remove to be deleted")
|
||||
}
|
||||
if _, ok := updated.Attributes["header:X-Nope"]; ok {
|
||||
t.Fatalf("expected attrs header:X-Nope to be absent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "noop.json",
|
||||
FileName: "noop.json",
|
||||
Provider: "claude",
|
||||
Attributes: map[string]string{
|
||||
"path": "/tmp/noop.json",
|
||||
"header:X-Kee": "1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "claude",
|
||||
"headers": map[string]any{
|
||||
"X-Kee": "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||
|
||||
body := `{"name":"noop.json","note":"hello","headers":{}}`
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx.Request = req
|
||||
h.PatchAuthFileFields(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID("noop.json")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth record to exist after patch")
|
||||
}
|
||||
if got := updated.Attributes["header:X-Kee"]; got != "1" {
|
||||
t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
|
||||
}
|
||||
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
|
||||
}
|
||||
if got := headersMeta["X-Kee"]; got != "1" {
|
||||
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
|
||||
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
|
||||
if len(apiWebsocketTimeline) > 0 {
|
||||
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
|
||||
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
data, ok := apiTimeline.([]byte)
|
||||
if !ok || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.Clone(data)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||
if !isExist {
|
||||
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
|
||||
return body
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
|
||||
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
|
||||
return body
|
||||
}
|
||||
if w.body == nil || w.body.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.Clone(w.body.Bytes())
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
|
||||
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
|
||||
}
|
||||
|
||||
func extractBodyOverride(c *gin.Context, key string) []byte {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
bodyOverride, isExist := c.Get(key)
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
websocketTimeline,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiWebsocketTimeline,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
websocketTimeline,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiWebsocketTimeline,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
w.requestInfo.Timestamp,
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||
wrapper.body.WriteString("original-response")
|
||||
|
||||
body := wrapper.extractResponseBody(c)
|
||||
if string(body) != "original-response" {
|
||||
t.Fatalf("response body = %q, want %q", string(body), "original-response")
|
||||
}
|
||||
|
||||
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
|
||||
body = wrapper.extractResponseBody(c)
|
||||
if string(body) != "override-response" {
|
||||
t.Fatalf("response body = %q, want %q", string(body), "override-response")
|
||||
}
|
||||
|
||||
body[0] = 'X'
|
||||
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
|
||||
t.Fatalf("response override should be cloned, got %q", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
|
||||
|
||||
body := wrapper.extractResponseBody(c)
|
||||
if string(body) != "override-response-as-string" {
|
||||
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
override := []byte("body-override")
|
||||
c.Set(requestBodyOverrideContextKey, override)
|
||||
|
||||
body := extractBodyOverride(c, requestBodyOverrideContextKey)
|
||||
if !bytes.Equal(body, override) {
|
||||
t.Fatalf("body override = %q, want %q", string(body), string(override))
|
||||
}
|
||||
|
||||
body[0] = 'X'
|
||||
if !bytes.Equal(override, []byte("body-override")) {
|
||||
t.Fatalf("override mutated: %q", string(override))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
if got := wrapper.extractWebsocketTimeline(c); got != nil {
|
||||
t.Fatalf("expected nil websocket timeline, got %q", string(got))
|
||||
}
|
||||
|
||||
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
|
||||
body := wrapper.extractWebsocketTimeline(c)
|
||||
if string(body) != "timeline" {
|
||||
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
streamWriter := &testStreamingLogWriter{}
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
ResponseWriter: c.Writer,
|
||||
logger: &testRequestLogger{enabled: true},
|
||||
requestInfo: &RequestInfo{
|
||||
URL: "/v1/responses",
|
||||
Method: "POST",
|
||||
Headers: map[string][]string{"Content-Type": {"application/json"}},
|
||||
RequestID: "req-1",
|
||||
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
|
||||
},
|
||||
isStreaming: true,
|
||||
streamWriter: streamWriter,
|
||||
}
|
||||
|
||||
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
|
||||
|
||||
if err := wrapper.Finalize(c); err != nil {
|
||||
t.Fatalf("Finalize error: %v", err)
|
||||
}
|
||||
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
|
||||
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
|
||||
}
|
||||
if !streamWriter.closed {
|
||||
t.Fatal("expected stream writer to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
type testRequestLogger struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
|
||||
return &testStreamingLogWriter{}, nil
|
||||
}
|
||||
|
||||
func (l *testRequestLogger) IsEnabled() bool {
|
||||
return l.enabled
|
||||
}
|
||||
|
||||
type testStreamingLogWriter struct {
|
||||
apiWebsocketTimeline []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
|
||||
|
||||
func (w *testStreamingLogWriter) Close() error {
|
||||
w.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -172,6 +172,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
"issue-1711",
|
||||
time.Now(),
|
||||
|
||||
@@ -88,7 +88,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
||||
"client_id": {ClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"scope": {"org:create_api_key user:profile user:inference"},
|
||||
"scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
@@ -41,15 +42,17 @@ type RequestLogger interface {
|
||||
// - statusCode: The response status code
|
||||
// - responseHeaders: The response headers
|
||||
// - response: The raw response data
|
||||
// - websocketTimeline: Optional downstream websocket event timeline
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - apiWebsocketTimeline: Optional upstream websocket event timeline
|
||||
// - requestID: Optional request ID for log file naming
|
||||
// - requestTimestamp: When the request was received
|
||||
// - apiResponseTimestamp: When the API response was received
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||
|
||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||
//
|
||||
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIResponse(apiResponse []byte) error
|
||||
|
||||
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
|
||||
// This should be called when upstream communication happened over websocket.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
|
||||
|
||||
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
requestHeaders,
|
||||
body,
|
||||
requestBodyPath,
|
||||
websocketTimeline,
|
||||
apiRequest,
|
||||
apiResponse,
|
||||
apiWebsocketTimeline,
|
||||
apiResponseErrors,
|
||||
statusCode,
|
||||
responseHeaders,
|
||||
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
requestHeaders map[string][]string,
|
||||
requestBody []byte,
|
||||
requestBodyPath string,
|
||||
websocketTimeline []byte,
|
||||
apiRequest []byte,
|
||||
apiResponse []byte,
|
||||
apiWebsocketTimeline []byte,
|
||||
apiResponseErrors []*interfaces.ErrorMessage,
|
||||
statusCode int,
|
||||
responseHeaders map[string][]string,
|
||||
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
if requestTimestamp.IsZero() {
|
||||
requestTimestamp = time.Now()
|
||||
}
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
||||
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
|
||||
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if isWebsocketTranscript {
|
||||
// Intentionally omit the generic downstream HTTP response section for websocket
|
||||
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
|
||||
// and appending a one-off upgrade response snapshot would dilute that transcript.
|
||||
return nil
|
||||
}
|
||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||
}
|
||||
|
||||
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
|
||||
body []byte,
|
||||
bodyPath string,
|
||||
timestamp time.Time,
|
||||
downstreamTransport string,
|
||||
upstreamTransport string,
|
||||
includeBody bool,
|
||||
) error {
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if strings.TrimSpace(downstreamTransport) != "" {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(upstreamTransport) != "" {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if !includeBody {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
bodyTrailingNewlines := 1
|
||||
if bodyPath != "" {
|
||||
bodyFile, errOpen := os.Open(bodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
||||
tracker := &trailingNewlineTrackingWriter{writer: w}
|
||||
written, errCopy := io.Copy(tracker, bodyFile)
|
||||
if errCopy != nil {
|
||||
_ = bodyFile.Close()
|
||||
return errCopy
|
||||
}
|
||||
if written > 0 {
|
||||
bodyTrailingNewlines = tracker.trailingNewlines
|
||||
}
|
||||
if errClose := bodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||
}
|
||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||
return errWrite
|
||||
} else if len(body) > 0 {
|
||||
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func countTrailingNewlinesBytes(payload []byte) int {
|
||||
count := 0
|
||||
for i := len(payload) - 1; i >= 0; i-- {
|
||||
if payload[i] != '\n' {
|
||||
break
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
|
||||
missingNewlines := 3 - trailingNewlines
|
||||
if missingNewlines <= 0 {
|
||||
return nil
|
||||
}
|
||||
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
|
||||
return errWrite
|
||||
}
|
||||
|
||||
type trailingNewlineTrackingWriter struct {
|
||||
writer io.Writer
|
||||
trailingNewlines int
|
||||
}
|
||||
|
||||
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
|
||||
written, errWrite := t.writer.Write(payload)
|
||||
if written > 0 {
|
||||
writtenPayload := payload[:written]
|
||||
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
|
||||
if trailingNewlines == len(writtenPayload) {
|
||||
t.trailingNewlines += trailingNewlines
|
||||
} else {
|
||||
t.trailingNewlines = trailingNewlines
|
||||
}
|
||||
}
|
||||
return written, errWrite
|
||||
}
|
||||
|
||||
func hasSectionPayload(payload []byte) bool {
|
||||
return len(bytes.TrimSpace(payload)) > 0
|
||||
}
|
||||
|
||||
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
|
||||
if hasSectionPayload(websocketTimeline) {
|
||||
return "websocket"
|
||||
}
|
||||
for key, values := range headers {
|
||||
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
|
||||
for _, value := range values {
|
||||
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
|
||||
return "websocket"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
|
||||
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
|
||||
hasWS := hasSectionPayload(apiWebsocketTimeline)
|
||||
switch {
|
||||
case hasHTTP && hasWS:
|
||||
return "websocket+http"
|
||||
case hasWS:
|
||||
return "websocket"
|
||||
case hasHTTP:
|
||||
return "http"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||
return errWrite
|
||||
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
trailingNewlines := 1
|
||||
if apiResponseErrors[i].Error != nil {
|
||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
||||
errText := apiResponseErrors[i].Error.Error()
|
||||
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errText != "" {
|
||||
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
var bufferedReader *bufio.Reader
|
||||
if responseReader != nil {
|
||||
bufferedReader = bufio.NewReader(responseReader)
|
||||
}
|
||||
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if responseReader != nil {
|
||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
||||
if bufferedReader != nil {
|
||||
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
|
||||
return errCopy
|
||||
}
|
||||
}
|
||||
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
|
||||
if reader == nil {
|
||||
return false
|
||||
}
|
||||
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
|
||||
return true
|
||||
}
|
||||
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - websocketTimeline: The downstream websocket event timeline
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - response: The raw response data
|
||||
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
//
|
||||
// Returns:
|
||||
// - string: The formatted log content
|
||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||
var content strings.Builder
|
||||
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
|
||||
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||
|
||||
// Request info
|
||||
content.WriteString(l.formatRequestInfo(url, method, headers, body))
|
||||
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
|
||||
|
||||
if len(websocketTimeline) > 0 {
|
||||
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
|
||||
content.Write(websocketTimeline)
|
||||
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
|
||||
content.Write(websocketTimeline)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(apiWebsocketTimeline) > 0 {
|
||||
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
|
||||
content.Write(apiWebsocketTimeline)
|
||||
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
|
||||
content.Write(apiWebsocketTimeline)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(apiRequest) > 0 {
|
||||
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
||||
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if isWebsocketTranscript {
|
||||
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
|
||||
// timeline sections instead of a generic downstream HTTP response block.
|
||||
return content.String()
|
||||
}
|
||||
|
||||
// Response section
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
|
||||
//
|
||||
// Returns:
|
||||
// - string: The formatted request information
|
||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
|
||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
|
||||
var content strings.Builder
|
||||
|
||||
content.WriteString("=== REQUEST INFO ===\n")
|
||||
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
||||
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||
if strings.TrimSpace(downstreamTransport) != "" {
|
||||
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
|
||||
}
|
||||
if strings.TrimSpace(upstreamTransport) != "" {
|
||||
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
|
||||
}
|
||||
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
content.WriteString("\n")
|
||||
|
||||
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
||||
}
|
||||
content.WriteString("\n")
|
||||
|
||||
if !includeBody {
|
||||
return content.String()
|
||||
}
|
||||
|
||||
content.WriteString("=== REQUEST BODY ===\n")
|
||||
content.Write(body)
|
||||
content.WriteString("\n\n")
|
||||
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
|
||||
// apiResponse stores the upstream API response data.
|
||||
apiResponse []byte
|
||||
|
||||
// apiWebsocketTimeline stores the upstream websocket event timeline.
|
||||
apiWebsocketTimeline []byte
|
||||
|
||||
// apiResponseTimestamp captures when the API response was received.
|
||||
apiResponseTimestamp time.Time
|
||||
}
|
||||
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil (buffering cannot fail)
|
||||
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||
if len(apiWebsocketTimeline) == 0 {
|
||||
return nil
|
||||
}
|
||||
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||
if !timestamp.IsZero() {
|
||||
w.apiResponseTimestamp = timestamp
|
||||
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||
|
||||
// Close finalizes the log file and cleans up resources.
|
||||
// It writes all buffered data to the file in the correct order:
|
||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if closing fails, nil otherwise
|
||||
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil
|
||||
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||
|
||||
// Close is a no-op implementation that does nothing and always returns nil.
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -47,8 +48,16 @@ func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Man
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio).
|
||||
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||
// PrepareRequest prepares the HTTP request for execution.
|
||||
func (e *AIStudioExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -67,6 +76,9 @@ func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, fmt.Errorf("aistudio executor: missing auth")
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
|
||||
return nil, fmt.Errorf("aistudio executor: request URL is empty")
|
||||
}
|
||||
@@ -131,6 +143,11 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: body.payload,
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -190,6 +207,11 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: body.payload,
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
|
||||
@@ -1320,6 +1320,11 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
@@ -1614,6 +1619,11 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -92,7 +92,7 @@ func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
@@ -355,7 +355,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
@@ -522,7 +522,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
@@ -813,7 +813,7 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
deviceProfile = helps.ResolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
|
||||
}
|
||||
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,structured-outputs-2025-12-15,fast-mode-2026-02-01,redact-thinking-2026-02-12,token-efficient-tools-2026-03-28"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
baseBetas = val
|
||||
if !strings.Contains(val, "oauth") {
|
||||
@@ -851,13 +851,22 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
// Only set browser access header for API key mode; real Claude Code CLI does not send it.
|
||||
if useAPIKey {
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
}
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
// Session ID: stable per auth/apiKey, matches Claude Code's X-Claude-Code-Session-Id header.
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Claude-Code-Session-Id", helps.CachedSessionID(apiKey))
|
||||
// Per-request UUID, matches Claude Code's x-client-request-id for first-party API.
|
||||
if isAnthropicBase {
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "x-client-request-id", uuid.New().String())
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
@@ -872,16 +881,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
|
||||
// to the configured baseline while still allowing newer official
|
||||
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
if stabilizeDeviceProfile {
|
||||
helps.ApplyClaudeDeviceProfileHeaders(r, deviceProfile)
|
||||
} else {
|
||||
helps.ApplyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||
@@ -907,7 +916,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
|
||||
func checkSystemInstructions(payload []byte) []byte {
|
||||
return checkSystemInstructionsWithSigningMode(payload, false, false)
|
||||
return checkSystemInstructionsWithSigningMode(payload, false, false, "2.1.63", "", "")
|
||||
}
|
||||
|
||||
func isClaudeOAuthToken(apiKey string) bool {
|
||||
@@ -1102,6 +1111,38 @@ func getClientUserAgent(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseEntrypointFromUA extracts the entrypoint from a Claude Code User-Agent.
|
||||
// Format: "claude-cli/x.y.z (external, cli)" → "cli"
|
||||
// Format: "claude-cli/x.y.z (external, vscode)" → "vscode"
|
||||
// Returns "cli" if parsing fails or UA is not Claude Code.
|
||||
func parseEntrypointFromUA(userAgent string) string {
|
||||
// Find content inside parentheses
|
||||
start := strings.Index(userAgent, "(")
|
||||
end := strings.LastIndex(userAgent, ")")
|
||||
if start < 0 || end <= start {
|
||||
return "cli"
|
||||
}
|
||||
inner := userAgent[start+1 : end]
|
||||
// Split by comma, take the second part (entrypoint is at index 1, after USER_TYPE)
|
||||
// Format: "(USER_TYPE, ENTRYPOINT[, extra...])"
|
||||
parts := strings.Split(inner, ",")
|
||||
if len(parts) >= 2 {
|
||||
ep := strings.TrimSpace(parts[1])
|
||||
if ep != "" {
|
||||
return ep
|
||||
}
|
||||
}
|
||||
return "cli"
|
||||
}
|
||||
|
||||
// getWorkloadFromContext extracts workload identifier from the gin request headers.
|
||||
func getWorkloadFromContext(ctx context.Context) string {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
return strings.TrimSpace(ginCtx.GetHeader("X-CPA-Claude-Workload"))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
|
||||
@@ -1152,28 +1193,52 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
return payload
|
||||
}
|
||||
|
||||
// fingerprintSalt is the salt used by Claude Code to compute the 3-char build fingerprint.
|
||||
const fingerprintSalt = "59cf53e54c78"
|
||||
|
||||
// computeFingerprint computes the 3-char build fingerprint that Claude Code embeds in cc_version.
|
||||
// Algorithm: SHA256(salt + messageText[4] + messageText[7] + messageText[20] + version)[:3]
|
||||
func computeFingerprint(messageText, version string) string {
|
||||
indices := [3]int{4, 7, 20}
|
||||
runes := []rune(messageText)
|
||||
var sb strings.Builder
|
||||
for _, idx := range indices {
|
||||
if idx < len(runes) {
|
||||
sb.WriteRune(runes[idx])
|
||||
} else {
|
||||
sb.WriteRune('0')
|
||||
}
|
||||
}
|
||||
input := fingerprintSalt + sb.String() + version
|
||||
h := sha256.Sum256([]byte(input))
|
||||
return hex.EncodeToString(h[:])[:3]
|
||||
}
|
||||
|
||||
// generateBillingHeader creates the x-anthropic-billing-header text block that
|
||||
// real Claude Code prepends to every system prompt array.
|
||||
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
|
||||
func generateBillingHeader(payload []byte, experimentalCCHSigning bool) string {
|
||||
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
|
||||
buildBytes := make([]byte, 2)
|
||||
_, _ = rand.Read(buildBytes)
|
||||
buildHash := hex.EncodeToString(buildBytes)[:3]
|
||||
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=<ep>; cch=<hash>; [cc_workload=<wl>;]
|
||||
func generateBillingHeader(payload []byte, experimentalCCHSigning bool, version, messageText, entrypoint, workload string) string {
|
||||
if entrypoint == "" {
|
||||
entrypoint = "cli"
|
||||
}
|
||||
buildHash := computeFingerprint(messageText, version)
|
||||
workloadPart := ""
|
||||
if workload != "" {
|
||||
workloadPart = fmt.Sprintf(" cc_workload=%s;", workload)
|
||||
}
|
||||
|
||||
if experimentalCCHSigning {
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=00000;", buildHash)
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=00000;%s", version, buildHash, entrypoint, workloadPart)
|
||||
}
|
||||
|
||||
// Generate a deterministic cch hash from the payload content (system + messages + tools).
|
||||
// Real Claude Code uses a 5-char hex hash that varies per request.
|
||||
h := sha256.Sum256(payload)
|
||||
cch := hex.EncodeToString(h[:])[:5]
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=%s;%s", version, buildHash, entrypoint, cch, workloadPart)
|
||||
}
|
||||
|
||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
return checkSystemInstructionsWithSigningMode(payload, strictMode, false)
|
||||
return checkSystemInstructionsWithSigningMode(payload, strictMode, false, "2.1.63", "", "")
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithSigningMode injects Claude Code-style system blocks:
|
||||
@@ -1181,10 +1246,25 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
// system[0]: billing header (no cache_control)
|
||||
// system[1]: agent identifier (no cache_control)
|
||||
// system[2..]: user system messages (cache_control added when missing)
|
||||
func checkSystemInstructionsWithSigningMode(payload []byte, strictMode bool, experimentalCCHSigning bool) []byte {
|
||||
func checkSystemInstructionsWithSigningMode(payload []byte, strictMode bool, experimentalCCHSigning bool, version, entrypoint, workload string) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
|
||||
billingText := generateBillingHeader(payload, experimentalCCHSigning)
|
||||
// Extract original message text for fingerprint computation (before billing injection).
|
||||
// Use the first system text block's content as the fingerprint source.
|
||||
messageText := ""
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
messageText = part.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if system.Type == gjson.String {
|
||||
messageText = system.String()
|
||||
}
|
||||
|
||||
billingText := generateBillingHeader(payload, experimentalCCHSigning, version, messageText, entrypoint, workload)
|
||||
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
|
||||
// No cache_control on the agent block. It is a cloaking artifact with zero cache
|
||||
// value (the last system block is what actually triggers caching of all system content).
|
||||
@@ -1273,7 +1353,10 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
|
||||
// Skip system instructions for claude-3-5-haiku models
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
payload = checkSystemInstructionsWithSigningMode(payload, strictMode, useExperimentalCCHSigning)
|
||||
billingVersion := helps.DefaultClaudeVersion(cfg)
|
||||
entrypoint := parseEntrypointFromUA(clientUserAgent)
|
||||
workload := getWorkloadFromContext(ctx)
|
||||
payload = checkSystemInstructionsWithSigningMode(payload, strictMode, useExperimentalCCHSigning, billingVersion, entrypoint, workload)
|
||||
}
|
||||
|
||||
// Inject fake user ID
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
|
||||
req := newClaudeHeaderTestRequest(t, incoming)
|
||||
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
assertClaudeFingerprint(t, req.Header, "evil-client/9.9", "9.9.9", "v24.5.0", "Linux", "x64")
|
||||
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
||||
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
wsReqLog := helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -229,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
}
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
|
||||
|
||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if respHS != nil {
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||
}
|
||||
if errDial != nil {
|
||||
bodyErr := websocketHandshakeBody(respHS)
|
||||
if len(bodyErr) > 0 {
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
||||
if respHS != nil {
|
||||
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
|
||||
}
|
||||
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
||||
return e.CodexExecutor.Execute(ctx, auth, req, opts)
|
||||
@@ -246,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
if respHS != nil && respHS.StatusCode > 0 {
|
||||
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
||||
}
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDial)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
|
||||
return resp, errDial
|
||||
}
|
||||
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
|
||||
if sess == nil {
|
||||
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
||||
defer func() {
|
||||
@@ -278,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
// Retry once with a fresh websocket connection. This is mainly to handle
|
||||
// upstream closing the socket between sequential requests within the same
|
||||
// execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry == nil && connRetry != nil {
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -292,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
|
||||
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
|
||||
conn = connRetry
|
||||
wsReqBody = wsReqBodyRetry
|
||||
} else {
|
||||
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
|
||||
return resp, errSendRetry
|
||||
}
|
||||
} else {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry)
|
||||
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
|
||||
return resp, errDialRetry
|
||||
}
|
||||
} else {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errSend)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
|
||||
return resp, errSend
|
||||
}
|
||||
}
|
||||
@@ -316,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
if msgType != websocket.TextMessage {
|
||||
@@ -325,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
||||
}
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
|
||||
return resp, err
|
||||
}
|
||||
continue
|
||||
@@ -335,13 +335,13 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, payload)
|
||||
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
|
||||
|
||||
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
||||
}
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, wsErr)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
|
||||
return resp, wsErr
|
||||
}
|
||||
|
||||
@@ -413,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
wsReqLog := helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -423,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
}
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
|
||||
|
||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
var upstreamHeaders http.Header
|
||||
if respHS != nil {
|
||||
upstreamHeaders = respHS.Header.Clone()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||
}
|
||||
if errDial != nil {
|
||||
bodyErr := websocketHandshakeBody(respHS)
|
||||
if len(bodyErr) > 0 {
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
||||
if respHS != nil {
|
||||
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
|
||||
}
|
||||
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
||||
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||
@@ -442,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
if respHS != nil && respHS.StatusCode > 0 {
|
||||
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
||||
}
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDial)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
|
||||
if sess != nil {
|
||||
sess.reqMu.Unlock()
|
||||
}
|
||||
return nil, errDial
|
||||
}
|
||||
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
|
||||
|
||||
if sess == nil {
|
||||
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
||||
@@ -461,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errSend)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
|
||||
|
||||
// Retry once with a new websocket connection for the same execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry != nil || connRetry == nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry)
|
||||
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
|
||||
sess.clearActive(readCh)
|
||||
sess.reqMu.Unlock()
|
||||
return nil, errDialRetry
|
||||
}
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -485,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
|
||||
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
|
||||
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
||||
sess.clearActive(readCh)
|
||||
sess.reqMu.Unlock()
|
||||
@@ -552,7 +554,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
terminateReason = "read_error"
|
||||
terminateErr = errRead
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
|
||||
reporter.PublishFailure(ctx)
|
||||
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
|
||||
return
|
||||
@@ -562,7 +564,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
err = fmt.Errorf("codex websockets executor: unexpected binary message")
|
||||
terminateReason = "unexpected_binary"
|
||||
terminateErr = err
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
|
||||
reporter.PublishFailure(ctx)
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
||||
@@ -577,12 +579,12 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, payload)
|
||||
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
|
||||
|
||||
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
||||
terminateReason = "upstream_error"
|
||||
terminateErr = wsErr
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, wsErr)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
|
||||
reporter.PublishFailure(ctx)
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
||||
@@ -1022,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
|
||||
return line
|
||||
}
|
||||
|
||||
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
|
||||
upgradeInfo := info
|
||||
upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL)
|
||||
upgradeInfo.Method = http.MethodGet
|
||||
upgradeInfo.Body = nil
|
||||
upgradeInfo.Headers = info.Headers.Clone()
|
||||
if upgradeInfo.Headers == nil {
|
||||
upgradeInfo.Headers = make(http.Header)
|
||||
}
|
||||
if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" {
|
||||
upgradeInfo.Headers.Set("Connection", "Upgrade")
|
||||
}
|
||||
if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" {
|
||||
upgradeInfo.Headers.Set("Upgrade", "websocket")
|
||||
}
|
||||
return upgradeInfo
|
||||
}
|
||||
|
||||
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
|
||||
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
|
||||
}
|
||||
|
||||
func websocketHandshakeBody(resp *http.Response) []byte {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil
|
||||
|
||||
@@ -82,6 +82,11 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(req, "unknown")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -191,6 +196,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
@@ -336,6 +342,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "text/event-stream")
|
||||
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
@@ -517,6 +524,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -363,6 +364,11 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
return resp, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -478,6 +484,11 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -582,6 +593,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
return nil, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -706,6 +722,11 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -813,6 +834,11 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -897,6 +923,11 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
|
||||
@@ -358,6 +358,16 @@ func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfil
|
||||
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||
}
|
||||
|
||||
// DefaultClaudeVersion returns the version string (e.g. "2.1.63") from the
|
||||
// current baseline device profile. It extracts the version from the User-Agent.
|
||||
func DefaultClaudeVersion(cfg *config.Config) string {
|
||||
profile := defaultClaudeDeviceProfile(cfg)
|
||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||
return strconv.Itoa(version.major) + "." + strconv.Itoa(version.minor) + "." + strconv.Itoa(version.patch)
|
||||
}
|
||||
return "2.1.63"
|
||||
}
|
||||
|
||||
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||
if r == nil {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,9 +20,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
||||
apiRequestKey = "API_REQUEST"
|
||||
apiResponseKey = "API_RESPONSE"
|
||||
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
||||
apiRequestKey = "API_REQUEST"
|
||||
apiResponseKey = "API_RESPONSE"
|
||||
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
|
||||
)
|
||||
|
||||
// UpstreamRequestLog captures the outbound upstream request details for logging.
|
||||
@@ -46,6 +48,7 @@ type upstreamAttempt struct {
|
||||
headersWritten bool
|
||||
bodyStarted bool
|
||||
bodyHasContent bool
|
||||
prevWasSSEEvent bool
|
||||
errorWritten bool
|
||||
}
|
||||
|
||||
@@ -173,15 +176,157 @@ func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
attempt.response.WriteString("Body:\n")
|
||||
attempt.bodyStarted = true
|
||||
}
|
||||
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
|
||||
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
|
||||
if attempt.bodyHasContent {
|
||||
attempt.response.WriteString("\n\n")
|
||||
separator := "\n\n"
|
||||
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
|
||||
separator = "\n"
|
||||
}
|
||||
attempt.response.WriteString(separator)
|
||||
}
|
||||
attempt.response.WriteString(string(data))
|
||||
attempt.bodyHasContent = true
|
||||
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
|
||||
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
|
||||
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.request\n")
|
||||
if info.URL != "" {
|
||||
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
|
||||
}
|
||||
if auth := formatAuthInfo(info); auth != "" {
|
||||
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
|
||||
}
|
||||
builder.WriteString("Headers:\n")
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.Write(info.Body)
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
|
||||
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.handshake\n")
|
||||
if status > 0 {
|
||||
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||
}
|
||||
builder.WriteString("Headers:\n")
|
||||
writeHeaders(builder, headers)
|
||||
builder.WriteString("\n")
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
|
||||
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
RecordAPIRequest(ctx, cfg, info)
|
||||
RecordAPIResponseMetadata(ctx, cfg, status, headers)
|
||||
AppendAPIResponseChunk(ctx, cfg, body)
|
||||
}
|
||||
|
||||
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
|
||||
func WebsocketUpgradeRequestURL(rawURL string) string {
|
||||
trimmedURL := strings.TrimSpace(rawURL)
|
||||
if trimmedURL == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(trimmedURL)
|
||||
if err != nil {
|
||||
return trimmedURL
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "ws":
|
||||
parsed.Scheme = "http"
|
||||
case "wss":
|
||||
parsed.Scheme = "https"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
|
||||
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(payload)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
markAPIResponseTimestamp(ginCtx)
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.response\n")
|
||||
builder.Write(data)
|
||||
builder.WriteString("\n")
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
|
||||
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
|
||||
if cfg == nil || !cfg.RequestLog || err == nil {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
markAPIResponseTimestamp(ginCtx)
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.error\n")
|
||||
if trimmed := strings.TrimSpace(stage); trimmed != "" {
|
||||
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
func ginContextFrom(ctx context.Context) *gin.Context {
|
||||
ginCtx, _ := ctx.Value("gin").(*gin.Context)
|
||||
return ginCtx
|
||||
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
|
||||
ginCtx.Set(apiResponseKey, []byte(builder.String()))
|
||||
}
|
||||
|
||||
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(chunk)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
combined := make([]byte, 0, len(existingBytes)+len(data)+2)
|
||||
combined = append(combined, existingBytes...)
|
||||
if !bytes.HasSuffix(existingBytes, []byte("\n")) {
|
||||
combined = append(combined, '\n')
|
||||
}
|
||||
combined = append(combined, '\n')
|
||||
combined = append(combined, data...)
|
||||
ginCtx.Set(apiWebsocketTimelineKey, combined)
|
||||
return
|
||||
}
|
||||
}
|
||||
ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data))
|
||||
}
|
||||
|
||||
func markAPIResponseTimestamp(ginCtx *gin.Context) {
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||
return
|
||||
}
|
||||
ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
|
||||
func writeHeaders(builder *strings.Builder, headers http.Header) {
|
||||
if builder == nil {
|
||||
return
|
||||
|
||||
92
internal/runtime/executor/helps/session_id_cache.go
Normal file
92
internal/runtime/executor/helps/session_id_cache.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type sessionIDCacheEntry struct {
|
||||
value string
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
sessionIDCache = make(map[string]sessionIDCacheEntry)
|
||||
sessionIDCacheMu sync.RWMutex
|
||||
sessionIDCacheCleanupOnce sync.Once
|
||||
)
|
||||
|
||||
const (
|
||||
sessionIDTTL = time.Hour
|
||||
sessionIDCacheCleanupPeriod = 15 * time.Minute
|
||||
)
|
||||
|
||||
func startSessionIDCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(sessionIDCacheCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredSessionIDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredSessionIDs() {
|
||||
now := time.Now()
|
||||
sessionIDCacheMu.Lock()
|
||||
for key, entry := range sessionIDCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(sessionIDCache, key)
|
||||
}
|
||||
}
|
||||
sessionIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func sessionIDCacheKey(apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(apiKey))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// CachedSessionID returns a stable session UUID per apiKey, refreshing the TTL on each access.
|
||||
func CachedSessionID(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
sessionIDCacheCleanupOnce.Do(startSessionIDCacheCleanup)
|
||||
|
||||
key := sessionIDCacheKey(apiKey)
|
||||
now := time.Now()
|
||||
|
||||
sessionIDCacheMu.RLock()
|
||||
entry, ok := sessionIDCache[key]
|
||||
valid := ok && entry.value != "" && entry.expire.After(now)
|
||||
sessionIDCacheMu.RUnlock()
|
||||
if valid {
|
||||
sessionIDCacheMu.Lock()
|
||||
entry = sessionIDCache[key]
|
||||
if entry.value != "" && entry.expire.After(now) {
|
||||
entry.expire = now.Add(sessionIDTTL)
|
||||
sessionIDCache[key] = entry
|
||||
sessionIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
sessionIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
newID := uuid.New().String()
|
||||
|
||||
sessionIDCacheMu.Lock()
|
||||
entry, ok = sessionIDCache[key]
|
||||
if !ok || entry.value == "" || !entry.expire.After(now) {
|
||||
entry.value = newID
|
||||
}
|
||||
entry.expire = now.Add(sessionIDTTL)
|
||||
sessionIDCache[key] = entry
|
||||
sessionIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
188
internal/runtime/executor/helps/utls_client.go
Normal file
188
internal/runtime/executor/helps/utls_client.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
type utlsRoundTripper struct {
|
||||
mu sync.Mutex
|
||||
connections map[string]*http2.ClientConn
|
||||
pending map[string]*sync.Cond
|
||||
dialer proxy.Dialer
|
||||
}
|
||||
|
||||
func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
|
||||
var dialer proxy.Dialer = proxy.Direct
|
||||
if proxyURL != "" {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
}
|
||||
return &utlsRoundTripper{
|
||||
connections: make(map[string]*http2.ClientConn),
|
||||
pending: make(map[string]*sync.Cond),
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
t.mu.Lock()
|
||||
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
if cond, ok := t.pending[host]; ok {
|
||||
cond.Wait()
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
cond := sync.NewCond(&t.mu)
|
||||
t.pending[host] = cond
|
||||
t.mu.Unlock()
|
||||
|
||||
h2Conn, err := t.createConnection(host, addr)
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
delete(t.pending, host)
|
||||
cond.Broadcast()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.connections[host] = h2Conn
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := t.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{}
|
||||
h2Conn, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
hostname := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
addr := net.JoinHostPort(hostname, port)
|
||||
|
||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := h2Conn.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.mu.Lock()
|
||||
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
|
||||
delete(t.connections, hostname)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// anthropicHosts contains the hosts that should use utls Chrome TLS fingerprint.
|
||||
var anthropicHosts = map[string]struct{}{
|
||||
"api.anthropic.com": {},
|
||||
}
|
||||
|
||||
// fallbackRoundTripper uses utls for Anthropic HTTPS hosts and falls back to
|
||||
// standard transport for all other requests (non-HTTPS or non-Anthropic hosts).
|
||||
type fallbackRoundTripper struct {
|
||||
utls *utlsRoundTripper
|
||||
fallback http.RoundTripper
|
||||
}
|
||||
|
||||
func (f *fallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Scheme == "https" {
|
||||
if _, ok := anthropicHosts[strings.ToLower(req.URL.Hostname())]; ok {
|
||||
return f.utls.RoundTrip(req)
|
||||
}
|
||||
}
|
||||
return f.fallback.RoundTrip(req)
|
||||
}
|
||||
|
||||
// NewUtlsHTTPClient creates an HTTP client using utls Chrome TLS fingerprint.
|
||||
// Use this for Claude API requests to match real Claude Code's TLS behavior.
|
||||
// Falls back to standard transport for non-HTTPS requests.
|
||||
func NewUtlsHTTPClient(cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
proxyURL = strings.TrimSpace(auth.ProxyURL)
|
||||
}
|
||||
if proxyURL == "" && cfg != nil {
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
utlsRT := newUtlsRoundTripper(proxyURL)
|
||||
|
||||
var standardTransport http.RoundTripper = &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
}
|
||||
if proxyURL != "" {
|
||||
if transport := buildProxyTransport(proxyURL); transport != nil {
|
||||
standardTransport = transport
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &fallbackRoundTripper{
|
||||
utls: utlsRT,
|
||||
fallback: standardTransport,
|
||||
},
|
||||
}
|
||||
if timeout > 0 {
|
||||
client.Timeout = timeout
|
||||
}
|
||||
return client
|
||||
}
|
||||
@@ -117,6 +117,11 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
return resp, err
|
||||
}
|
||||
applyIFlowHeaders(httpReq, apiKey, false)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -225,6 +230,11 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
applyIFlowHeaders(httpReq, apiKey, true)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -46,6 +47,11 @@ func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -114,6 +120,11 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -218,6 +229,11 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -257,6 +258,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, false)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authLabel = auth.Label
|
||||
@@ -367,6 +373,11 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, true)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authLabel = auth.Label
|
||||
|
||||
@@ -446,6 +446,7 @@ func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||
auth.Attributes["email"] = email
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -595,6 +595,7 @@ func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Aut
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -310,6 +310,7 @@ func (s *PostgresStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error)
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
auths = append(auths, auth)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
|
||||
@@ -157,6 +157,7 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
|
||||
}
|
||||
}
|
||||
}
|
||||
coreauth.ApplyCustomHeadersFromMetadata(a)
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
// For codex auth files, extract plan_type from the JWT id_token.
|
||||
if provider == "codex" {
|
||||
@@ -233,6 +234,11 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" {
|
||||
attrs["note"] = noteVal
|
||||
}
|
||||
for k, v := range primary.Attributes {
|
||||
if strings.HasPrefix(k, "header:") && strings.TrimSpace(v) != "" {
|
||||
attrs[k] = v
|
||||
}
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
|
||||
@@ -69,10 +69,14 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
|
||||
// Create a valid auth file
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
"headers": map[string]string{
|
||||
" X-Test ": " value ",
|
||||
"X-Empty": " ",
|
||||
},
|
||||
"disable_cooling": true,
|
||||
"request_retry": 2,
|
||||
}
|
||||
@@ -110,6 +114,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if got := auths[0].Attributes["header:X-Test"]; got != "value" {
|
||||
t.Errorf("expected header:X-Test value, got %q", got)
|
||||
}
|
||||
if _, ok := auths[0].Attributes["header:X-Empty"]; ok {
|
||||
t.Errorf("expected header:X-Empty to be absent, got %q", auths[0].Attributes["header:X-Empty"])
|
||||
}
|
||||
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
|
||||
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
|
||||
}
|
||||
@@ -450,8 +460,9 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
Prefix: "test-prefix",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
"header:X-Tra": "value",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
@@ -506,6 +517,9 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
if v.Attributes["runtime_only"] != "true" {
|
||||
t.Error("expected runtime_only=true")
|
||||
}
|
||||
if got := v.Attributes["header:X-Tra"]; got != "value" {
|
||||
t.Errorf("expected virtual %d header:X-Tra %q, got %q", i, "value", got)
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != "primary-id" {
|
||||
t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
|
||||
@@ -136,6 +136,8 @@ type authAwareStreamExecutor struct {
|
||||
|
||||
type invalidJSONStreamExecutor struct{}
|
||||
|
||||
type splitResponsesEventStreamExecutor struct{}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -165,6 +167,36 @@ func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *corea
|
||||
}
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) Identifier() string { return "split-sse" }
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
ch := make(chan coreexecutor.StreamChunk, 2)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed")}
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")}
|
||||
close(ch)
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -607,3 +639,52 @@ func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_AllowsSplitOpenAIResponsesSSEEventLines(t *testing.T) {
|
||||
executor := &splitResponsesEventStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "split-sse",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []string
|
||||
for chunk := range dataChan {
|
||||
got = append(got, string(chunk))
|
||||
}
|
||||
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 forwarded chunks, got %d: %#v", len(got), got)
|
||||
}
|
||||
if got[0] != "event: response.completed" {
|
||||
t.Fatalf("unexpected first chunk: %q", got[0])
|
||||
}
|
||||
expectedData := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
|
||||
if got[1] != expectedData {
|
||||
t.Fatalf("unexpected second chunk.\nGot: %q\nWant: %q", got[1], expectedData)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// gatewayHeaderPrefixes lists header name prefixes injected by known AI gateway
|
||||
// proxies. Claude Code's client-side telemetry detects these and reports the
|
||||
// gateway type, so we strip them from upstream responses to avoid detection.
|
||||
var gatewayHeaderPrefixes = []string{
|
||||
"x-litellm-",
|
||||
"helicone-",
|
||||
"x-portkey-",
|
||||
"cf-aig-",
|
||||
"x-kong-",
|
||||
"x-bt-",
|
||||
}
|
||||
|
||||
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
|
||||
// be forwarded by proxies, plus security-sensitive headers that should not leak.
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
@@ -40,6 +52,19 @@ func FilterUpstreamHeaders(src http.Header) http.Header {
|
||||
if _, scoped := connectionScoped[canonicalKey]; scoped {
|
||||
continue
|
||||
}
|
||||
// Strip headers injected by known AI gateway proxies to avoid
|
||||
// Claude Code client-side gateway detection.
|
||||
lowerKey := strings.ToLower(key)
|
||||
gatewayMatch := false
|
||||
for _, prefix := range gatewayHeaderPrefixes {
|
||||
if strings.HasPrefix(lowerKey, prefix) {
|
||||
gatewayMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if gatewayMatch {
|
||||
continue
|
||||
}
|
||||
dst[key] = values
|
||||
}
|
||||
if len(dst) == 0 {
|
||||
|
||||
@@ -9,6 +9,7 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -30,11 +31,13 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if bytes.HasSuffix(chunk, []byte("\n\n")) {
|
||||
if bytes.HasSuffix(chunk, []byte("\n\n")) || bytes.HasSuffix(chunk, []byte("\r\n\r\n")) {
|
||||
return
|
||||
}
|
||||
suffix := []byte("\n\n")
|
||||
if bytes.HasSuffix(chunk, []byte("\n")) {
|
||||
if bytes.HasSuffix(chunk, []byte("\r\n")) {
|
||||
suffix = []byte("\r\n")
|
||||
} else if bytes.HasSuffix(chunk, []byte("\n")) {
|
||||
suffix = []byte("\n")
|
||||
}
|
||||
if _, err := w.Write(suffix); err != nil {
|
||||
@@ -42,6 +45,156 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
type responsesSSEFramer struct {
|
||||
pending []byte
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
if responsesSSENeedsLineBreak(f.pending, chunk) {
|
||||
f.pending = append(f.pending, '\n')
|
||||
}
|
||||
f.pending = append(f.pending, chunk...)
|
||||
for {
|
||||
frameLen := responsesSSEFrameLen(f.pending)
|
||||
if frameLen == 0 {
|
||||
break
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending[:frameLen])
|
||||
copy(f.pending, f.pending[frameLen:])
|
||||
f.pending = f.pending[:len(f.pending)-frameLen]
|
||||
}
|
||||
if len(bytes.TrimSpace(f.pending)) == 0 {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
|
||||
return
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending)
|
||||
f.pending = f.pending[:0]
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) Flush(w io.Writer) {
|
||||
if len(f.pending) == 0 {
|
||||
return
|
||||
}
|
||||
if len(bytes.TrimSpace(f.pending)) == 0 {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
if !responsesSSECanEmitWithoutDelimiter(f.pending) {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending)
|
||||
f.pending = f.pending[:0]
|
||||
}
|
||||
|
||||
func responsesSSEFrameLen(chunk []byte) int {
|
||||
if len(chunk) == 0 {
|
||||
return 0
|
||||
}
|
||||
lf := bytes.Index(chunk, []byte("\n\n"))
|
||||
crlf := bytes.Index(chunk, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0:
|
||||
if crlf < 0 {
|
||||
return 0
|
||||
}
|
||||
return crlf + 4
|
||||
case crlf < 0:
|
||||
return lf + 2
|
||||
case lf < crlf:
|
||||
return lf + 2
|
||||
default:
|
||||
return crlf + 4
|
||||
}
|
||||
}
|
||||
|
||||
func responsesSSENeedsMoreData(chunk []byte) bool {
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if len(trimmed) == 0 {
|
||||
return false
|
||||
}
|
||||
return responsesSSEHasField(trimmed, []byte("event:")) && !responsesSSEHasField(trimmed, []byte("data:"))
|
||||
}
|
||||
|
||||
func responsesSSEHasField(chunk []byte, prefix []byte) bool {
|
||||
s := chunk
|
||||
for len(s) > 0 {
|
||||
line := s
|
||||
if i := bytes.IndexByte(s, '\n'); i >= 0 {
|
||||
line = s[:i]
|
||||
s = s[i+1:]
|
||||
} else {
|
||||
s = nil
|
||||
}
|
||||
line = bytes.TrimSpace(line)
|
||||
if bytes.HasPrefix(line, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func responsesSSECanEmitWithoutDelimiter(chunk []byte) bool {
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if len(trimmed) == 0 || responsesSSENeedsMoreData(trimmed) || !responsesSSEHasField(trimmed, []byte("data:")) {
|
||||
return false
|
||||
}
|
||||
return responsesSSEDataLinesValid(trimmed)
|
||||
}
|
||||
|
||||
func responsesSSEDataLinesValid(chunk []byte) bool {
|
||||
s := chunk
|
||||
for len(s) > 0 {
|
||||
line := s
|
||||
if i := bytes.IndexByte(s, '\n'); i >= 0 {
|
||||
line = s[:i]
|
||||
s = s[i+1:]
|
||||
} else {
|
||||
s = nil
|
||||
}
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len("data:"):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !json.Valid(data) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func responsesSSENeedsLineBreak(pending, chunk []byte) bool {
|
||||
if len(pending) == 0 || len(chunk) == 0 {
|
||||
return false
|
||||
}
|
||||
if bytes.HasSuffix(pending, []byte("\n")) || bytes.HasSuffix(pending, []byte("\r")) {
|
||||
return false
|
||||
}
|
||||
if chunk[0] == '\n' || chunk[0] == '\r' {
|
||||
return false
|
||||
}
|
||||
trimmed := bytes.TrimLeft(chunk, " \t")
|
||||
if len(trimmed) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range [][]byte{[]byte("data:"), []byte("event:"), []byte("id:"), []byte("retry:"), []byte(":")} {
|
||||
if bytes.HasPrefix(trimmed, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type OpenAIResponsesAPIHandler struct {
|
||||
@@ -254,6 +407,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
framer := &responsesSSEFramer{}
|
||||
|
||||
// Peek at the first chunk
|
||||
for {
|
||||
@@ -291,11 +445,11 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
writeResponsesSSEChunk(c.Writer, chunk)
|
||||
framer.WriteChunk(c.Writer, chunk)
|
||||
flusher.Flush()
|
||||
|
||||
// Continue
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, framer)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -413,12 +567,16 @@ func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, framer *responsesSSEFramer) {
|
||||
if framer == nil {
|
||||
framer = &responsesSSEFramer{}
|
||||
}
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
writeResponsesSSEChunk(c.Writer, chunk)
|
||||
framer.WriteChunk(c.Writer, chunk)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
framer.Flush(c.Writer)
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
@@ -434,6 +592,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||
},
|
||||
WriteDone: func() {
|
||||
framer.Flush(c.Writer)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T
|
||||
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
body := recorder.Body.String()
|
||||
if !strings.Contains(body, `"type":"error"`) {
|
||||
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
@@ -26,6 +28,12 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||
}
|
||||
|
||||
return h, recorder, c, flusher
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 2)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}")
|
||||
@@ -33,7 +41,7 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
body := recorder.Body.String()
|
||||
parts := strings.Split(strings.TrimSpace(body), "\n\n")
|
||||
if len(parts) != 2 {
|
||||
@@ -50,3 +58,85 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 3)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("event: response.created")
|
||||
data <- []byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}")
|
||||
data <- []byte("\n")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
got := strings.TrimSuffix(recorder.Body.String(), "\n")
|
||||
want := "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected split-event framing.\nGot: %q\nWant: %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamPreservesValidFullSSEEventChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
chunk := []byte("event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n")
|
||||
data <- chunk
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
got := strings.TrimSuffix(recorder.Body.String(), "\n")
|
||||
if got != string(chunk) {
|
||||
t.Fatalf("unexpected full-event framing.\nGot: %q\nWant: %q", got, string(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamBuffersSplitDataPayloadChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 2)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.created\"")
|
||||
data <- []byte(",\"response\":{\"id\":\"resp-1\"}}")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
got := recorder.Body.String()
|
||||
want := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n\n"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected split-data framing.\nGot: %q\nWant: %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesSSENeedsLineBreakSkipsChunksThatAlreadyStartWithNewline(t *testing.T) {
|
||||
if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\n")) {
|
||||
t.Fatal("expected no injected newline before newline-only chunk")
|
||||
}
|
||||
if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\r\n")) {
|
||||
t.Fatal("expected no injected newline before CRLF chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamDropsIncompleteTrailingDataChunkOnFlush(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.created\"")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
if got := recorder.Body.String(); got != "\n" {
|
||||
t.Fatalf("expected incomplete trailing data to be dropped on flush.\nGot: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ const (
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsTimelineBodyKey = "WEBSOCKET_TIMELINE_OVERRIDE"
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
@@ -57,10 +57,11 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
clientIP := websocketClientAddress(c)
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
|
||||
var wsTerminateErr error
|
||||
var wsBodyLog strings.Builder
|
||||
var wsTimelineLog strings.Builder
|
||||
defer func() {
|
||||
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||
if wsTerminateErr != nil {
|
||||
appendWebsocketTimelineDisconnect(&wsTimelineLog, wsTerminateErr, time.Now())
|
||||
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||
} else {
|
||||
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
||||
@@ -69,7 +70,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
||||
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
||||
}
|
||||
setWebsocketRequestBody(c, wsBodyLog.String())
|
||||
setWebsocketTimelineBody(c, wsTimelineLog.String())
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Warnf("responses websocket: close connection error: %v", errClose)
|
||||
}
|
||||
@@ -83,7 +84,6 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
wsTerminateErr = errReadMessage
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
||||
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
} else {
|
||||
@@ -101,7 +101,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
// websocketPayloadEventType(payload),
|
||||
// websocketPayloadPreview(payload),
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now())
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := false
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
@@ -128,8 +128,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, &wsTimelineLog, errMsg)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
passthroughSessionID,
|
||||
@@ -157,9 +156,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
lastResponseOutput = []byte("[]")
|
||||
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
||||
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsTimelineLog, passthroughSessionID); errWrite != nil {
|
||||
wsTerminateErr = errWrite
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
||||
return
|
||||
}
|
||||
continue
|
||||
@@ -192,10 +190,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
wsTerminateErr = errForward
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
||||
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||
return
|
||||
}
|
||||
@@ -597,7 +594,7 @@ func writeResponsesWebsocketSyntheticPrewarm(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
requestJSON []byte,
|
||||
wsBodyLog *strings.Builder,
|
||||
wsTimelineLog *strings.Builder,
|
||||
sessionID string,
|
||||
) error {
|
||||
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
||||
@@ -606,7 +603,6 @@ func writeResponsesWebsocketSyntheticPrewarm(
|
||||
}
|
||||
for i := 0; i < len(payloads); i++ {
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
@@ -614,7 +610,7 @@ func writeResponsesWebsocketSyntheticPrewarm(
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
@@ -713,7 +709,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
cancel handlers.APIHandlerCancelFunc,
|
||||
data <-chan []byte,
|
||||
errs <-chan *interfaces.ErrorMessage,
|
||||
wsBodyLog *strings.Builder,
|
||||
wsTimelineLog *strings.Builder,
|
||||
sessionID string,
|
||||
) ([]byte, error) {
|
||||
completed := false
|
||||
@@ -736,8 +732,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
@@ -771,8 +766,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
}
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
@@ -806,7 +800,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
@@ -814,7 +807,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
if errWrite := writeResponsesWebsocketPayload(conn, wsTimelineLog, payloads[i], time.Now()); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
@@ -870,7 +863,7 @@ func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
||||
return payloads
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
||||
func writeResponsesWebsocketError(conn *websocket.Conn, wsTimelineLog *strings.Builder, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
||||
status := http.StatusInternalServerError
|
||||
errText := http.StatusText(status)
|
||||
if errMsg != nil {
|
||||
@@ -940,7 +933,7 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
|
||||
}
|
||||
}
|
||||
|
||||
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
||||
return payload, writeResponsesWebsocketPayload(conn, wsTimelineLog, payload, time.Now())
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
@@ -979,7 +972,11 @@ func websocketPayloadPreview(payload []byte) string {
|
||||
return previewText
|
||||
}
|
||||
|
||||
func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||
func setWebsocketTimelineBody(c *gin.Context, body string) {
|
||||
setWebsocketBody(c, wsTimelineBodyKey, body)
|
||||
}
|
||||
|
||||
func setWebsocketBody(c *gin.Context, key string, body string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
@@ -987,7 +984,40 @@ func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||
if trimmedBody == "" {
|
||||
return
|
||||
}
|
||||
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
||||
c.Set(key, []byte(trimmedBody))
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketPayload(conn *websocket.Conn, wsTimelineLog *strings.Builder, payload []byte, timestamp time.Time) error {
|
||||
appendWebsocketTimelineEvent(wsTimelineLog, "response", payload, timestamp)
|
||||
return conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func appendWebsocketTimelineDisconnect(builder *strings.Builder, err error, timestamp time.Time) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
appendWebsocketTimelineEvent(builder, "disconnect", []byte(err.Error()), timestamp)
|
||||
}
|
||||
|
||||
func appendWebsocketTimelineEvent(builder *strings.Builder, eventType string, payload []byte, timestamp time.Time) {
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString("Timestamp: ")
|
||||
builder.WriteString(timestamp.Format(time.RFC3339Nano))
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("Event: websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
func markAPIResponseTimestamp(c *gin.Context) {
|
||||
|
||||
@@ -392,27 +392,45 @@ func TestAppendWebsocketEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
func TestAppendWebsocketTimelineEvent(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
ts := time.Date(2026, time.April, 1, 12, 34, 56, 789000000, time.UTC)
|
||||
|
||||
appendWebsocketTimelineEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"), ts)
|
||||
|
||||
got := builder.String()
|
||||
if !strings.Contains(got, "Timestamp: 2026-04-01T12:34:56.789Z") {
|
||||
t.Fatalf("timeline timestamp not found: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, "Event: websocket.request") {
|
||||
t.Fatalf("timeline event not found: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, "{\"type\":\"response.create\"}") {
|
||||
t.Fatalf("timeline payload not found: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketTimelineBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
setWebsocketRequestBody(c, " \n ")
|
||||
if _, exists := c.Get(wsRequestBodyKey); exists {
|
||||
t.Fatalf("request body key should not be set for empty body")
|
||||
setWebsocketTimelineBody(c, " \n ")
|
||||
if _, exists := c.Get(wsTimelineBodyKey); exists {
|
||||
t.Fatalf("timeline body key should not be set for empty body")
|
||||
}
|
||||
|
||||
setWebsocketRequestBody(c, "event body")
|
||||
value, exists := c.Get(wsRequestBodyKey)
|
||||
setWebsocketTimelineBody(c, "timeline body")
|
||||
value, exists := c.Get(wsTimelineBodyKey)
|
||||
if !exists {
|
||||
t.Fatalf("request body key not set")
|
||||
t.Fatalf("timeline body key not set")
|
||||
}
|
||||
bodyBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("request body key type mismatch")
|
||||
t.Fatalf("timeline body key type mismatch")
|
||||
}
|
||||
if string(bodyBytes) != "event body" {
|
||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||
if string(bodyBytes) != "timeline body" {
|
||||
t.Fatalf("timeline body = %q, want %q", string(bodyBytes), "timeline body")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -544,14 +562,14 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
close(data)
|
||||
close(errCh)
|
||||
|
||||
var bodyLog strings.Builder
|
||||
var timelineLog strings.Builder
|
||||
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
|
||||
ctx,
|
||||
conn,
|
||||
func(...interface{}) {},
|
||||
data,
|
||||
errCh,
|
||||
&bodyLog,
|
||||
&timelineLog,
|
||||
"session-1",
|
||||
)
|
||||
if err != nil {
|
||||
@@ -562,6 +580,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
serverErrCh <- errors.New("completed output not captured")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(timelineLog.String(), "Event: websocket.response") {
|
||||
serverErrCh <- errors.New("websocket timeline did not capture downstream response")
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}))
|
||||
defer server.Close()
|
||||
@@ -594,6 +616,116 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Request = r
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errCh := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
|
||||
close(data)
|
||||
close(errCh)
|
||||
|
||||
var timelineLog strings.Builder
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
serverErrCh <- errClose
|
||||
return
|
||||
}
|
||||
|
||||
_, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
|
||||
ctx,
|
||||
conn,
|
||||
func(...interface{}) {},
|
||||
data,
|
||||
errCh,
|
||||
&timelineLog,
|
||||
"session-1",
|
||||
)
|
||||
if err == nil {
|
||||
serverErrCh <- errors.New("expected websocket write failure")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(timelineLog.String(), "Event: websocket.response") {
|
||||
serverErrCh <- errors.New("websocket timeline did not capture attempted downstream response")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(timelineLog.String(), "\"type\":\"response.completed\"") {
|
||||
serverErrCh <- errors.New("websocket timeline did not retain attempted payload")
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
if errServer := <-serverErrCh; errServer != nil {
|
||||
t.Fatalf("server error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
timelineCh := make(chan string, 1)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", func(c *gin.Context) {
|
||||
h.ResponsesWebsocket(c)
|
||||
timeline := ""
|
||||
if value, exists := c.Get(wsTimelineBodyKey); exists {
|
||||
if body, ok := value.([]byte); ok {
|
||||
timeline = string(body)
|
||||
}
|
||||
}
|
||||
timelineCh <- timeline
|
||||
})
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
|
||||
closePayload := websocket.FormatCloseMessage(websocket.CloseGoingAway, "client closing")
|
||||
if err = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)); err != nil {
|
||||
t.Fatalf("write close control: %v", err)
|
||||
}
|
||||
_ = conn.Close()
|
||||
|
||||
select {
|
||||
case timeline := <-timelineCh:
|
||||
if !strings.Contains(timeline, "Event: websocket.disconnect") {
|
||||
t.Fatalf("websocket timeline missing disconnect event: %s", timeline)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for websocket timeline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
|
||||
@@ -263,6 +263,7 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||
auth.Attributes["email"] = email
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
|
||||
68
sdk/cliproxy/auth/custom_headers.go
Normal file
68
sdk/cliproxy/auth/custom_headers.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package auth
|
||||
|
||||
import "strings"
|
||||
|
||||
func ExtractCustomHeadersFromMetadata(metadata map[string]any) map[string]string {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
raw, ok := metadata["headers"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(map[string]string)
|
||||
switch headers := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range headers {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(value)
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
out[name] = val
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range headers {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
rawVal, ok := value.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(rawVal)
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
out[name] = val
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ApplyCustomHeadersFromMetadata(auth *Auth) {
|
||||
if auth == nil || len(auth.Metadata) == 0 {
|
||||
return
|
||||
}
|
||||
headers := ExtractCustomHeadersFromMetadata(auth.Metadata)
|
||||
if len(headers) == 0 {
|
||||
return
|
||||
}
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
for name, value := range headers {
|
||||
auth.Attributes["header:"+name] = value
|
||||
}
|
||||
}
|
||||
50
sdk/cliproxy/auth/custom_headers_test.go
Normal file
50
sdk/cliproxy/auth/custom_headers_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractCustomHeadersFromMetadata(t *testing.T) {
|
||||
meta := map[string]any{
|
||||
"headers": map[string]any{
|
||||
" X-Test ": " value ",
|
||||
"": "ignored",
|
||||
"X-Empty": " ",
|
||||
"X-Num": float64(1),
|
||||
},
|
||||
}
|
||||
|
||||
got := ExtractCustomHeadersFromMetadata(meta)
|
||||
want := map[string]string{"X-Test": "value"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("ExtractCustomHeadersFromMetadata() = %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCustomHeadersFromMetadata(t *testing.T) {
|
||||
auth := &Auth{
|
||||
Metadata: map[string]any{
|
||||
"headers": map[string]string{
|
||||
"X-Test": "new",
|
||||
"X-Empty": " ",
|
||||
},
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"header:X-Test": "old",
|
||||
"keep": "1",
|
||||
},
|
||||
}
|
||||
|
||||
ApplyCustomHeadersFromMetadata(auth)
|
||||
|
||||
if got := auth.Attributes["header:X-Test"]; got != "new" {
|
||||
t.Fatalf("header:X-Test = %q, want %q", got, "new")
|
||||
}
|
||||
if _, ok := auth.Attributes["header:X-Empty"]; ok {
|
||||
t.Fatalf("expected header:X-Empty to be absent, got %#v", auth.Attributes["header:X-Empty"])
|
||||
}
|
||||
if got := auth.Attributes["keep"]; got != "1" {
|
||||
t.Fatalf("keep = %q, want %q", got, "1")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user