mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-24 09:21:02 +00:00
Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cdb9c2e6e8 | ||
|
|
3faa1ca9af | ||
|
|
9d975e0375 | ||
|
|
2a6d8b78d4 | ||
|
|
6b80ec79a0 | ||
|
|
d3f4783a24 | ||
|
|
1cb6bdbc87 | ||
|
|
96ddfc1f24 | ||
|
|
c169b32570 | ||
|
|
36a512fdf2 | ||
|
|
26fbb77901 | ||
|
|
5dc0dbc7aa | ||
|
|
ee6fc4e8a1 | ||
|
|
8fee16aecd | ||
|
|
2b7ba54a2f | ||
|
|
007c3304f2 | ||
|
|
e76ba0ede9 | ||
|
|
c06ac07e23 | ||
|
|
e592a57458 | ||
|
|
66769ec657 | ||
|
|
f413feec61 | ||
|
|
2e538e3486 | ||
|
|
9617a7b0d6 | ||
|
|
7569320770 | ||
|
|
8d25cf0d75 | ||
|
|
64e85e7019 | ||
|
|
a862984dca | ||
|
|
f0365f0465 | ||
|
|
6d1e20e940 | ||
|
|
0c0aae1eac | ||
|
|
5dcf7cb846 | ||
|
|
349b2ba3af | ||
|
|
98db5aabd0 | ||
|
|
e52b542e22 | ||
|
|
8f6abb8a86 | ||
|
|
ed8eaae964 | ||
|
|
7fd98f3556 | ||
|
|
e8de87ee90 | ||
|
|
4e572ec8b9 | ||
|
|
6c7f18c448 | ||
|
|
24bc9cba67 | ||
|
|
5bf89dd757 | ||
|
|
4442574e53 | ||
|
|
71a6dffbb6 | ||
|
|
24e8e20b59 | ||
|
|
a87f09bad2 | ||
|
|
bc6c4cdbfc | ||
|
|
404546ce93 | ||
|
|
6dd1cf1dd6 | ||
|
|
9058d406a3 |
BIN
assets/cubence.png
Normal file
BIN
assets/cubence.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
BIN
assets/packycode.png
Normal file
BIN
assets/packycode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.1 KiB |
@@ -78,6 +78,11 @@ routing:
|
|||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||||
|
# streaming:
|
||||||
|
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||||
|
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||||
|
|
||||||
# Gemini API keys
|
# Gemini API keys
|
||||||
# gemini-api-key:
|
# gemini-api-key:
|
||||||
# - api-key: "AIzaSy...01"
|
# - api-key: "AIzaSy...01"
|
||||||
|
|||||||
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
||||||
|
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
||||||
|
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
||||||
|
if h == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.cfg == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := h.logDirectory()
|
||||||
|
if strings.TrimSpace(dir) == "" {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strings.TrimSpace(c.Param("id"))
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = strings.TrimSpace(c.Query("id"))
|
||||||
|
}
|
||||||
|
if requestID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(requestID, "/\\") {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
suffix := "-" + requestID + ".log"
|
||||||
|
var matchedFile string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := entry.Name()
|
||||||
|
if strings.HasSuffix(name, suffix) {
|
||||||
|
matchedFile = name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matchedFile == "" {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dirAbs, errAbs := filepath.Abs(dir)
|
||||||
|
if errAbs != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
||||||
|
prefix := dirAbs + string(os.PathSeparator)
|
||||||
|
if !strings.HasPrefix(fullPath, prefix) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, errStat := os.Stat(fullPath)
|
||||||
|
if errStat != nil {
|
||||||
|
if os.IsNotExist(errStat) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.FileAttachment(fullPath, matchedFile)
|
||||||
|
}
|
||||||
|
|
||||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
|
|||||||
@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &RequestInfo{
|
return &RequestInfo{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: method,
|
Method: method,
|
||||||
Headers: headers,
|
Headers: headers,
|
||||||
Body: body,
|
Body: body,
|
||||||
|
RequestID: logging.GetGinRequestID(c),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ import (
|
|||||||
|
|
||||||
// RequestInfo captures the parts of an incoming HTTP request that are kept
// for logging.
type RequestInfo struct {
	// URL is the request URL.
	URL string
	// Method is the HTTP method (e.g., GET, POST).
	Method string
	// Headers contains the request headers.
	Headers map[string][]string
	// Body is the raw request body.
	Body []byte
	// RequestID is the unique identifier for the request.
	RequestID string
}
|
||||||
|
|
||||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||||
@@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
|||||||
w.requestInfo.Method,
|
w.requestInfo.Method,
|
||||||
w.requestInfo.Headers,
|
w.requestInfo.Headers,
|
||||||
w.requestInfo.Body,
|
w.requestInfo.Body,
|
||||||
|
w.requestInfo.RequestID,
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
w.streamWriter = streamWriter
|
w.streamWriter = streamWriter
|
||||||
@@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
||||||
}); ok {
|
}); ok {
|
||||||
return loggerWithOptions.LogRequestWithOptions(
|
return loggerWithOptions.LogRequestWithOptions(
|
||||||
w.requestInfo.URL,
|
w.requestInfo.URL,
|
||||||
@@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
|||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
forceLog,
|
forceLog,
|
||||||
|
w.requestInfo.RequestID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
|||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
|
w.requestInfo.RequestID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,16 +279,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build map for efficient comparison
|
// Build map for efficient and robust comparison
|
||||||
oldMap := make(map[string]string, len(old.ModelMappings))
|
type mappingInfo struct {
|
||||||
|
to string
|
||||||
|
regex bool
|
||||||
|
}
|
||||||
|
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
|
||||||
for _, mapping := range old.ModelMappings {
|
for _, mapping := range old.ModelMappings {
|
||||||
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
|
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
|
||||||
|
to: strings.TrimSpace(mapping.To),
|
||||||
|
regex: mapping.Regex,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, mapping := range new.ModelMappings {
|
for _, mapping := range new.ModelMappings {
|
||||||
from := strings.TrimSpace(mapping.From)
|
from := strings.TrimSpace(mapping.From)
|
||||||
to := strings.TrimSpace(mapping.To)
|
to := strings.TrimSpace(mapping.To)
|
||||||
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
|
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -26,13 +27,15 @@ type ModelMapper interface {
|
|||||||
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||||
type DefaultModelMapper struct {
|
type DefaultModelMapper struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
mappings map[string]string // from -> to (normalized lowercase keys)
|
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||||
|
regexps []regexMapping // regex rules evaluated in order
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||||
m := &DefaultModelMapper{
|
m := &DefaultModelMapper{
|
||||||
mappings: make(map[string]string),
|
mappings: make(map[string]string),
|
||||||
|
regexps: nil,
|
||||||
}
|
}
|
||||||
m.UpdateMappings(mappings)
|
m.UpdateMappings(mappings)
|
||||||
return m
|
return m
|
||||||
@@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
|||||||
// Check for direct mapping
|
// Check for direct mapping
|
||||||
targetModel, exists := m.mappings[normalizedRequest]
|
targetModel, exists := m.mappings[normalizedRequest]
|
||||||
if !exists {
|
if !exists {
|
||||||
return ""
|
// Try regex mappings in order
|
||||||
|
base, _ := util.NormalizeThinkingModel(requestedModel)
|
||||||
|
for _, rm := range m.regexps {
|
||||||
|
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
|
||||||
|
targetModel = rm.to
|
||||||
|
exists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify target model has available providers
|
// Verify target model has available providers
|
||||||
@@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
|||||||
|
|
||||||
// Clear and rebuild mappings
|
// Clear and rebuild mappings
|
||||||
m.mappings = make(map[string]string, len(mappings))
|
m.mappings = make(map[string]string, len(mappings))
|
||||||
|
m.regexps = make([]regexMapping, 0, len(mappings))
|
||||||
|
|
||||||
for _, mapping := range mappings {
|
for _, mapping := range mappings {
|
||||||
from := strings.TrimSpace(mapping.From)
|
from := strings.TrimSpace(mapping.From)
|
||||||
@@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store with normalized lowercase key for case-insensitive lookup
|
if mapping.Regex {
|
||||||
normalizedFrom := strings.ToLower(from)
|
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
|
||||||
m.mappings[normalizedFrom] = to
|
pattern := "(?i)" + from
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
if err != nil {
|
||||||
|
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
|
||||||
|
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
|
||||||
|
} else {
|
||||||
|
// Store with normalized lowercase key for case-insensitive lookup
|
||||||
|
normalizedFrom := strings.ToLower(from)
|
||||||
|
m.mappings[normalizedFrom] = to
|
||||||
|
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.mappings) > 0 {
|
if len(m.mappings) > 0 {
|
||||||
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||||
}
|
}
|
||||||
|
if n := len(m.regexps); n > 0 {
|
||||||
|
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMappings returns a copy of current mappings (for debugging/status).
|
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||||
@@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// regexMapping is a single regex-based rule: a compiled pattern together with
// the model name it maps to.
type regexMapping struct {
	// re is the compiled pattern tested against the requested model name.
	re *regexp.Regexp
	// to is the target model name used when re matches.
	to string
}
|
||||||
|
|||||||
@@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
|||||||
t.Error("Original map was modified")
|
t.Error("Original map was modified")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
|
||||||
|
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client-regex-1")
|
||||||
|
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// Incoming model has reasoning suffix but should match base via regex
|
||||||
|
result := mapper.MapModel("gpt-5(high)")
|
||||||
|
if result != "gemini-2.5-pro" {
|
||||||
|
t.Errorf("Expected gemini-2.5-pro, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||||
|
})
|
||||||
|
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
|
||||||
|
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client-regex-2")
|
||||||
|
defer reg.UnregisterClient("test-client-regex-3")
|
||||||
|
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
|
||||||
|
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// Exact match should win over regex
|
||||||
|
result := mapper.MapModel("gpt-5")
|
||||||
|
if result != "claude-sonnet-4" {
|
||||||
|
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
|
||||||
|
// Invalid regex should be skipped and not cause panic
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "(", To: "target", Regex: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
result := mapper.MapModel("anything")
|
||||||
|
if result != "" {
|
||||||
|
t.Errorf("Expected empty result due to invalid regex, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client-regex-4")
|
||||||
|
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
result := mapper.MapModel("claude-opus-4.5")
|
||||||
|
if result != "claude-sonnet-4" {
|
||||||
|
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -538,6 +538,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||||
|
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
||||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||||
|
|||||||
@@ -40,6 +40,10 @@ type KiroTokenData struct {
|
|||||||
ClientSecret string `json:"clientSecret,omitempty"`
|
ClientSecret string `json:"clientSecret,omitempty"`
|
||||||
// Email is the user's email address (used for file naming)
|
// Email is the user's email address (used for file naming)
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
|
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||||
|
StartURL string `json:"startUrl,omitempty"`
|
||||||
|
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||||
|
|||||||
@@ -2,16 +2,19 @@
|
|||||||
package kiro
|
package kiro
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html"
|
"html"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,19 +27,31 @@ import (
|
|||||||
const (
	// ssoOIDCEndpoint is the AWS SSO OIDC endpoint (us-east-1).
	ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"

	// builderIDStartURL is Kiro's start URL for Builder ID.
	builderIDStartURL = "https://view.awsapps.com/start"

	// defaultIDCRegion is the region used for IDC when none is given.
	defaultIDCRegion = "us-east-1"

	// pollInterval is the polling interval.
	pollInterval = 5 * time.Second

	// authCodeCallbackPath and authCodeCallbackPort describe the
	// authorization code flow callback.
	authCodeCallbackPath = "/oauth/callback"
	authCodeCallbackPort = 19877

	// kiroUserAgent is the User-Agent sent to match the official Kiro IDE.
	kiroUserAgent = "KiroIDE"

	// idcAmzUserAgent is the x-amz-user-agent value sent on IDC token refresh
	// (matching Kiro IDE behavior).
	idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
)

// Sentinel errors for OIDC token polling.
var (
	// ErrAuthorizationPending corresponds to the "authorization_pending" error code.
	ErrAuthorizationPending = errors.New("authorization_pending")
	// ErrSlowDown corresponds to the "slow_down" error code.
	ErrSlowDown = errors.New("slow_down")
)
|
||||||
|
|
||||||
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
||||||
@@ -83,6 +98,440 @@ type CreateTokenResponse struct {
|
|||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getOIDCEndpoint returns the OIDC endpoint for the given region.
|
||||||
|
func getOIDCEndpoint(region string) string {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptInput prompts the user for input with an optional default value.
|
||||||
|
func promptInput(prompt, defaultValue string) string {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
if defaultValue != "" {
|
||||||
|
fmt.Printf("%s [%s]: ", prompt, defaultValue)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%s: ", prompt)
|
||||||
|
}
|
||||||
|
input, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Error reading input: %v", err)
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
input = strings.TrimSpace(input)
|
||||||
|
if input == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptSelect prompts the user to select from options using number input.
|
||||||
|
func promptSelect(prompt string, options []string) int {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
|
||||||
|
for {
|
||||||
|
fmt.Println(prompt)
|
||||||
|
for i, opt := range options {
|
||||||
|
fmt.Printf(" %d) %s\n", i+1, opt)
|
||||||
|
}
|
||||||
|
fmt.Printf("Enter selection (1-%d): ", len(options))
|
||||||
|
|
||||||
|
input, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Error reading input: %v", err)
|
||||||
|
return 0 // Default to first option on error
|
||||||
|
}
|
||||||
|
input = strings.TrimSpace(input)
|
||||||
|
|
||||||
|
// Parse the selection
|
||||||
|
var selection int
|
||||||
|
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
|
||||||
|
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return selection - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
|
||||||
|
func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"clientName": "Kiro IDE",
|
||||||
|
"clientType": "public",
|
||||||
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||||
|
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result RegisterClientResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
|
||||||
|
func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"startUrl": startURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result StartDeviceAuthResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
|
||||||
|
func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"deviceCode": deviceCode,
|
||||||
|
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for pending authorization
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
var errResp struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
|
if errResp.Error == "authorization_pending" {
|
||||||
|
return nil, ErrAuthorizationPending
|
||||||
|
}
|
||||||
|
if errResp.Error == "slow_down" {
|
||||||
|
return nil, ErrSlowDown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("create token failed: %s", string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
|
||||||
|
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
"grantType": "refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers matching kiro2api's IDC token refresh
|
||||||
|
// These headers are required for successful IDC token refresh
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
||||||
|
req.Header.Set("Accept", "*/*")
|
||||||
|
req.Header.Set("Accept-Language", "*")
|
||||||
|
req.Header.Set("sec-fetch-mode", "cors")
|
||||||
|
req.Header.Set("User-Agent", "node")
|
||||||
|
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: result.AccessToken,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "idc",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
|
||||||
|
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Step 1: Register client with the specified region
|
||||||
|
fmt.Println("\nRegistering client...")
|
||||||
|
regResp, err := c.RegisterClientWithRegion(ctx, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||||
|
|
||||||
|
// Step 2: Start device authorization with IDC start URL
|
||||||
|
fmt.Println("Starting device authorization...")
|
||||||
|
authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start device auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Show user the verification URL
|
||||||
|
fmt.Printf("\n")
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf(" Confirm the following code in the browser:\n")
|
||||||
|
fmt.Printf(" Code: %s\n", authResp.UserCode)
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
|
||||||
|
|
||||||
|
// Set incognito mode based on config
|
||||||
|
if c.cfg != nil {
|
||||||
|
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||||
|
if !c.cfg.IncognitoBrowser {
|
||||||
|
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||||
|
} else {
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
browser.SetIncognitoMode(true)
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open browser
|
||||||
|
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
|
||||||
|
log.Warnf("Could not open browser automatically: %v", err)
|
||||||
|
fmt.Println(" Please open the URL manually in your browser.")
|
||||||
|
} else {
|
||||||
|
fmt.Println(" (Browser opened automatically)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Poll for token
|
||||||
|
fmt.Println("Waiting for authorization...")
|
||||||
|
|
||||||
|
interval := pollInterval
|
||||||
|
if authResp.Interval > 0 {
|
||||||
|
interval = time.Duration(authResp.Interval) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(interval):
|
||||||
|
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrAuthorizationPending) {
|
||||||
|
fmt.Print(".")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrSlowDown) {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, fmt.Errorf("token creation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n\n✓ Authorization successful!")
|
||||||
|
|
||||||
|
// Close the browser window
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Get profile ARN from CodeWhisperer API
|
||||||
|
fmt.Println("Fetching profile information...")
|
||||||
|
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||||
|
|
||||||
|
// Fetch user email
|
||||||
|
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "idc",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: regResp.ClientID,
|
||||||
|
ClientSecret: regResp.ClientSecret,
|
||||||
|
Email: email,
|
||||||
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close browser on timeout
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
||||||
|
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Authentication (AWS) ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Prompt for login method
|
||||||
|
options := []string{
|
||||||
|
"Use with Builder ID (personal AWS account)",
|
||||||
|
"Use with IDC Account (organization SSO)",
|
||||||
|
}
|
||||||
|
selection := promptSelect("\n? Select login method:", options)
|
||||||
|
|
||||||
|
if selection == 0 {
|
||||||
|
// Builder ID flow - use existing implementation
|
||||||
|
return c.LoginWithBuilderID(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDC flow - prompt for start URL and region
|
||||||
|
fmt.Println()
|
||||||
|
startURL := promptInput("? Enter Start URL", "")
|
||||||
|
if startURL == "" {
|
||||||
|
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||||
|
}
|
||||||
|
|
||||||
|
region := promptInput("? Enter Region", defaultIDCRegion)
|
||||||
|
|
||||||
|
return c.LoginWithIDC(ctx, startURL, region)
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterClient registers a new OIDC client with AWS.
|
// RegisterClient registers a new OIDC client with AWS.
|
||||||
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||||
payload := map[string]interface{}{
|
payload := map[string]interface{}{
|
||||||
@@ -211,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
|||||||
}
|
}
|
||||||
if json.Unmarshal(respBody, &errResp) == nil {
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
if errResp.Error == "authorization_pending" {
|
if errResp.Error == "authorization_pending" {
|
||||||
return nil, fmt.Errorf("authorization_pending")
|
return nil, ErrAuthorizationPending
|
||||||
}
|
}
|
||||||
if errResp.Error == "slow_down" {
|
if errResp.Error == "slow_down" {
|
||||||
return nil, fmt.Errorf("slow_down")
|
return nil, ErrSlowDown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("create token failed: %s", string(respBody))
|
log.Debugf("create token failed: %s", string(respBody))
|
||||||
@@ -359,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
|||||||
case <-time.After(interval):
|
case <-time.After(interval):
|
||||||
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
if errors.Is(err, ErrAuthorizationPending) {
|
||||||
if strings.Contains(errStr, "authorization_pending") {
|
|
||||||
fmt.Print(".")
|
fmt.Print(".")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.Contains(errStr, "slow_down") {
|
if errors.Is(err, ErrSlowDown) {
|
||||||
interval += 5 * time.Second
|
interval += 5 * time.Second
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -156,6 +156,11 @@ type AmpModelMapping struct {
|
|||||||
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
||||||
// The target model must have available providers in the registry.
|
// The target model must have available providers in the registry.
|
||||||
To string `yaml:"to" json:"to"`
|
To string `yaml:"to" json:"to"`
|
||||||
|
|
||||||
|
// Regex indicates whether the 'from' field should be interpreted as a regular
|
||||||
|
// expression for matching model names. When true, this mapping is evaluated
|
||||||
|
// after exact matches and in the order provided. Defaults to false (exact match).
|
||||||
|
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AmpCode groups Amp CLI integration settings including upstream routing,
|
// AmpCode groups Amp CLI integration settings including upstream routing,
|
||||||
@@ -401,7 +406,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.DisableCooling = false
|
cfg.DisableCooling = false
|
||||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
if optional {
|
if optional {
|
||||||
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
||||||
|
|||||||
@@ -22,6 +22,21 @@ type SDKConfig struct {
|
|||||||
|
|
||||||
// Access holds request authentication provider configuration.
|
// Access holds request authentication provider configuration.
|
||||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||||
|
|
||||||
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamingConfig holds server streaming behavior configuration.
|
||||||
|
type StreamingConfig struct {
|
||||||
|
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||||
|
// nil means default (15 seconds). <= 0 disables keep-alives.
|
||||||
|
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||||
|
|
||||||
|
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||||
|
// to allow auth rotation / transient recovery.
|
||||||
|
// nil means default (2). 0 disables bootstrap retries.
|
||||||
|
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccessConfig groups request authentication providers.
|
// AccessConfig groups request authentication providers.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -14,11 +15,24 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// aiAPIPrefixes lists the URL path prefixes treated as AI API requests.
// Requests whose path starts with one of these get a request ID for tracking.
var aiAPIPrefixes = []string{
	// Chat, completion, message, and response endpoints under /v1.
	"/v1/chat/completions", "/v1/completions", "/v1/messages", "/v1/responses",
	// Model endpoints under /v1beta.
	"/v1beta/models/",
	// Provider-scoped routes.
	"/api/provider/",
}
|
||||||
|
|
||||||
const skipGinLogKey = "__gin_skip_request_logging__"
|
const skipGinLogKey = "__gin_skip_request_logging__"
|
||||||
|
|
||||||
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
||||||
// using logrus. It captures request details including method, path, status code, latency,
|
// using logrus. It captures request details including method, path, status code, latency,
|
||||||
// client IP, and any error messages, formatting them in a Gin-style log format.
|
// client IP, and any error messages. Request ID is only added for AI API requests.
|
||||||
|
//
|
||||||
|
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
|
||||||
|
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - gin.HandlerFunc: A middleware handler for request logging
|
// - gin.HandlerFunc: A middleware handler for request logging
|
||||||
@@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
|
|||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
|
|
||||||
|
// Only generate request ID for AI API paths
|
||||||
|
var requestID string
|
||||||
|
if isAIAPIPath(path) {
|
||||||
|
requestID = GenerateRequestID()
|
||||||
|
SetGinRequestID(c, requestID)
|
||||||
|
ctx := WithRequestID(c.Request.Context(), requestID)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
if shouldSkipGinRequestLogging(c) {
|
if shouldSkipGinRequestLogging(c) {
|
||||||
@@ -49,23 +72,40 @@ func GinLogrusLogger() gin.HandlerFunc {
|
|||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
method := c.Request.Method
|
method := c.Request.Method
|
||||||
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||||
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
|
|
||||||
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
|
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
|
||||||
if errorMessage != "" {
|
if errorMessage != "" {
|
||||||
logLine = logLine + " | " + errorMessage
|
logLine = logLine + " | " + errorMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var entry *log.Entry
|
||||||
|
if requestID != "" {
|
||||||
|
entry = log.WithField("request_id", requestID)
|
||||||
|
} else {
|
||||||
|
entry = log.WithField("request_id", "--------")
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case statusCode >= http.StatusInternalServerError:
|
case statusCode >= http.StatusInternalServerError:
|
||||||
log.Error(logLine)
|
entry.Error(logLine)
|
||||||
case statusCode >= http.StatusBadRequest:
|
case statusCode >= http.StatusBadRequest:
|
||||||
log.Warn(logLine)
|
entry.Warn(logLine)
|
||||||
default:
|
default:
|
||||||
log.Info(logLine)
|
entry.Info(logLine)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
|
||||||
|
func isAIAPIPath(path string) bool {
|
||||||
|
for _, prefix := range aiAPIPrefixes {
|
||||||
|
if strings.HasPrefix(path, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
|
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
|
||||||
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
|
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
|
||||||
// and request path, then returns a 500 Internal Server Error response to the client.
|
// and request path, then returns a 500 Internal Server Error response to the client.
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// LogFormatter defines a custom log format for logrus.
|
// LogFormatter defines a custom log format for logrus.
|
||||||
// This formatter adds timestamp, level, and source location to each log entry.
|
// This formatter adds timestamp, level, request ID, and source location to each log entry.
|
||||||
|
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
|
||||||
type LogFormatter struct{}
|
type LogFormatter struct{}
|
||||||
|
|
||||||
// Format renders a single log entry with custom formatting.
|
// Format renders a single log entry with custom formatting.
|
||||||
@@ -38,16 +39,27 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
|
|
||||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||||
message := strings.TrimRight(entry.Message, "\r\n")
|
message := strings.TrimRight(entry.Message, "\r\n")
|
||||||
|
|
||||||
// Handle nil Caller (can happen with some log entries)
|
reqID := ""
|
||||||
|
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
|
||||||
|
reqID = id
|
||||||
|
}
|
||||||
|
|
||||||
callerFile := "unknown"
|
callerFile := "unknown"
|
||||||
callerLine := 0
|
callerLine := 0
|
||||||
if entry.Caller != nil {
|
if entry.Caller != nil {
|
||||||
callerFile = filepath.Base(entry.Caller.File)
|
callerFile = filepath.Base(entry.Caller.File)
|
||||||
callerLine = entry.Caller.Line
|
callerLine = entry.Caller.Line
|
||||||
}
|
}
|
||||||
|
|
||||||
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
|
levelStr := fmt.Sprintf("%-5s", entry.Level.String())
|
||||||
|
|
||||||
|
var formatted string
|
||||||
|
if reqID != "" {
|
||||||
|
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] | %s | %s\n", timestamp, levelStr, callerFile, callerLine, reqID, message)
|
||||||
|
} else {
|
||||||
|
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, levelStr, callerFile, callerLine, message)
|
||||||
|
}
|
||||||
buffer.WriteString(formatted)
|
buffer.WriteString(formatted)
|
||||||
|
|
||||||
return buffer.Bytes(), nil
|
return buffer.Bytes(), nil
|
||||||
|
|||||||
@@ -43,10 +43,11 @@ type RequestLogger interface {
|
|||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
|
// - requestID: Optional request ID for log file naming
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
|
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
|
||||||
|
|
||||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||||
//
|
//
|
||||||
@@ -55,11 +56,12 @@ type RequestLogger interface {
|
|||||||
// - method: The HTTP method
|
// - method: The HTTP method
|
||||||
// - headers: The request headers
|
// - headers: The request headers
|
||||||
// - body: The request body
|
// - body: The request body
|
||||||
|
// - requestID: Optional request ID for log file naming
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - StreamingLogWriter: A writer for streaming response chunks
|
// - StreamingLogWriter: A writer for streaming response chunks
|
||||||
// - error: An error if logging initialization fails, nil otherwise
|
// - error: An error if logging initialization fails, nil otherwise
|
||||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
|
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
|
||||||
|
|
||||||
// IsEnabled returns whether request logging is currently enabled.
|
// IsEnabled returns whether request logging is currently enabled.
|
||||||
//
|
//
|
||||||
@@ -177,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
|||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
|
// - requestID: Optional request ID for log file naming
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
|
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||||
if !l.enabled && !force {
|
if !l.enabled && !force {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -200,10 +203,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
|||||||
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate filename
|
// Generate filename with request ID
|
||||||
filename := l.generateFilename(url)
|
filename := l.generateFilename(url, requestID)
|
||||||
if force && !l.enabled {
|
if force && !l.enabled {
|
||||||
filename = l.generateErrorFilename(url)
|
filename = l.generateErrorFilename(url, requestID)
|
||||||
}
|
}
|
||||||
filePath := filepath.Join(l.logsDir, filename)
|
filePath := filepath.Join(l.logsDir, filename)
|
||||||
|
|
||||||
@@ -271,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
|||||||
// - method: The HTTP method
|
// - method: The HTTP method
|
||||||
// - headers: The request headers
|
// - headers: The request headers
|
||||||
// - body: The request body
|
// - body: The request body
|
||||||
|
// - requestID: Optional request ID for log file naming
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - StreamingLogWriter: A writer for streaming response chunks
|
// - StreamingLogWriter: A writer for streaming response chunks
|
||||||
// - error: An error if logging initialization fails, nil otherwise
|
// - error: An error if logging initialization fails, nil otherwise
|
||||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
|
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
|
||||||
if !l.enabled {
|
if !l.enabled {
|
||||||
return &NoOpStreamingLogWriter{}, nil
|
return &NoOpStreamingLogWriter{}, nil
|
||||||
}
|
}
|
||||||
@@ -285,8 +289,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
|||||||
return nil, fmt.Errorf("failed to create logs directory: %w", err)
|
return nil, fmt.Errorf("failed to create logs directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate filename
|
// Generate filename with request ID
|
||||||
filename := l.generateFilename(url)
|
filename := l.generateFilename(url, requestID)
|
||||||
filePath := filepath.Join(l.logsDir, filename)
|
filePath := filepath.Join(l.logsDir, filename)
|
||||||
|
|
||||||
requestHeaders := make(map[string][]string, len(headers))
|
requestHeaders := make(map[string][]string, len(headers))
|
||||||
@@ -330,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
||||||
func (l *FileRequestLogger) generateErrorFilename(url string) string {
|
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
|
||||||
return fmt.Sprintf("error-%s", l.generateFilename(url))
|
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureLogsDir creates the logs directory if it doesn't exist.
|
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||||
@@ -346,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateFilename creates a sanitized filename from the URL path and current timestamp.
|
// generateFilename creates a sanitized filename from the URL path and current timestamp.
|
||||||
|
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - url: The request URL
|
// - url: The request URL
|
||||||
|
// - requestID: Optional request ID to include in filename
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: A sanitized filename for the log file
|
// - string: A sanitized filename for the log file
|
||||||
func (l *FileRequestLogger) generateFilename(url string) string {
|
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
|
||||||
// Extract path from URL
|
// Extract path from URL
|
||||||
path := url
|
path := url
|
||||||
if strings.Contains(url, "?") {
|
if strings.Contains(url, "?") {
|
||||||
@@ -368,12 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
|
|||||||
sanitized := l.sanitizeForFilename(path)
|
sanitized := l.sanitizeForFilename(path)
|
||||||
|
|
||||||
// Add timestamp
|
// Add timestamp
|
||||||
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
|
timestamp := time.Now().Format("2006-01-02T150405")
|
||||||
timestamp = strings.Replace(timestamp, ".", "", -1)
|
|
||||||
|
|
||||||
id := requestLogID.Add(1)
|
// Use request ID if provided, otherwise use sequential ID
|
||||||
|
var idPart string
|
||||||
|
if len(requestID) > 0 && requestID[0] != "" {
|
||||||
|
idPart = requestID[0]
|
||||||
|
} else {
|
||||||
|
id := requestLogID.Add(1)
|
||||||
|
idPart = fmt.Sprintf("%d", id)
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
|
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sanitizeForFilename replaces characters that are not safe for filenames.
|
// sanitizeForFilename replaces characters that are not safe for filenames.
|
||||||
|
|||||||
61
internal/logging/requestid.go
Normal file
61
internal/logging/requestid.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// requestIDKey is the context key for storing/retrieving request IDs.
|
||||||
|
type requestIDKey struct{}
|
||||||
|
|
||||||
|
// ginRequestIDKey is the Gin context key for request IDs.
|
||||||
|
const ginRequestIDKey = "__request_id__"
|
||||||
|
|
||||||
|
// GenerateRequestID creates a new 8-character hex request ID.
|
||||||
|
func GenerateRequestID() string {
|
||||||
|
b := make([]byte, 4)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "00000000"
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestID returns a new context with the request ID attached.
|
||||||
|
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||||
|
return context.WithValue(ctx, requestIDKey{}, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestID retrieves the request ID from the context.
|
||||||
|
// Returns empty string if not found.
|
||||||
|
func GetRequestID(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGinRequestID stores the request ID in the Gin context.
|
||||||
|
func SetGinRequestID(c *gin.Context, requestID string) {
|
||||||
|
if c != nil {
|
||||||
|
c.Set(ginRequestIDKey, requestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGinRequestID retrieves the request ID from the Gin context.
|
||||||
|
func GetGinRequestID(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if id, exists := c.Get(ginRequestIDKey); exists {
|
||||||
|
if s, ok := id.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||||
@@ -740,6 +741,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||||
|
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000},
|
||||||
}
|
}
|
||||||
models := make([]*ModelInfo, 0, len(entries))
|
models := make([]*ModelInfo, 0, len(entries))
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ const (
|
|||||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
antigravityClientSecret = "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"
|
antigravityClientSecret = "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"
|
||||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
||||||
antigravityAuthType = "antigravity"
|
antigravityAuthType = "antigravity"
|
||||||
refreshSkew = 3000 * time.Second
|
refreshSkew = 3000 * time.Second
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -662,7 +662,14 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
|||||||
}
|
}
|
||||||
|
|
||||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||||
|
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||||
|
if isAnthropicBase && useAPIKey {
|
||||||
|
r.Header.Del("Authorization")
|
||||||
|
r.Header.Set("x-api-key", apiKey)
|
||||||
|
} else {
|
||||||
|
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
var ginHeaders http.Header
|
var ginHeaders http.Header
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -43,10 +45,15 @@ const (
|
|||||||
// Event Stream error type constants
|
// Event Stream error type constants
|
||||||
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
||||||
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
||||||
// kiroUserAgent matches amq2api format for User-Agent header
|
// kiroUserAgent matches amq2api format for User-Agent header (Amazon Q CLI style)
|
||||||
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
||||||
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api
|
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api (Amazon Q CLI style)
|
||||||
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
|
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
|
||||||
|
|
||||||
|
// Kiro IDE style headers (from kiro2api - for IDC auth)
|
||||||
|
kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||||
|
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||||
|
kiroIDEAgentModeSpec = "spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Real-time usage estimation configuration
|
// Real-time usage estimation configuration
|
||||||
@@ -101,11 +108,24 @@ var kiroEndpointConfigs = []kiroEndpointConfig{
|
|||||||
|
|
||||||
// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
|
// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
|
||||||
// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
|
// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
|
||||||
|
// For IDC auth method, automatically uses CodeWhisperer endpoint with CLI origin.
|
||||||
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return kiroEndpointConfigs
|
return kiroEndpointConfigs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For IDC auth, use CodeWhisperer endpoint with AI_EDITOR origin (same as Social auth)
|
||||||
|
// Based on kiro2api analysis: IDC tokens work with CodeWhisperer endpoint using Bearer auth
|
||||||
|
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
|
||||||
|
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
|
||||||
|
if auth.Metadata != nil {
|
||||||
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
if authMethod == "idc" {
|
||||||
|
log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint")
|
||||||
|
return kiroEndpointConfigs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check for preference
|
// Check for preference
|
||||||
var preference string
|
var preference string
|
||||||
if auth.Metadata != nil {
|
if auth.Metadata != nil {
|
||||||
@@ -162,6 +182,15 @@ type KiroExecutor struct {
|
|||||||
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
|
||||||
|
func isIDCAuth(auth *cliproxyauth.Auth) bool {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
return authMethod == "idc"
|
||||||
|
}
|
||||||
|
|
||||||
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
||||||
// This is critical because OpenAI and Claude formats have different tool structures:
|
// This is critical because OpenAI and Claude formats have different tool structures:
|
||||||
// - OpenAI: tools[].function.name, tools[].function.description
|
// - OpenAI: tools[].function.name, tools[].function.description
|
||||||
@@ -210,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||||
} else if refreshedAuth != nil {
|
} else if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
log.Infof("kiro: token refreshed successfully before request")
|
log.Infof("kiro: token refreshed successfully before request")
|
||||||
}
|
}
|
||||||
@@ -262,15 +295,28 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", kiroContentType)
|
httpReq.Header.Set("Content-Type", kiroContentType)
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
httpReq.Header.Set("Accept", kiroAcceptStream)
|
httpReq.Header.Set("Accept", kiroAcceptStream)
|
||||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
// Use different headers based on auth type
|
||||||
|
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||||
|
// Other auth types use Amazon Q CLI style headers
|
||||||
|
if isIDCAuth(auth) {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||||
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||||
|
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||||
|
} else {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||||
|
}
|
||||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||||
|
|
||||||
|
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
var attrs map[string]string
|
var attrs map[string]string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
@@ -358,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
|
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
// Rebuild payload with new profile ARN if changed
|
// Rebuild payload with new profile ARN if changed
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
@@ -416,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
log.Infof("kiro: token refreshed for 403, retrying request")
|
log.Infof("kiro: token refreshed for 403, retrying request")
|
||||||
@@ -518,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||||
} else if refreshedAuth != nil {
|
} else if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
log.Infof("kiro: token refreshed successfully before stream request")
|
log.Infof("kiro: token refreshed successfully before stream request")
|
||||||
}
|
}
|
||||||
@@ -568,15 +628,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", kiroContentType)
|
httpReq.Header.Set("Content-Type", kiroContentType)
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
httpReq.Header.Set("Accept", kiroAcceptStream)
|
httpReq.Header.Set("Accept", kiroAcceptStream)
|
||||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
// Use different headers based on auth type
|
||||||
|
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||||
|
// Other auth types use Amazon Q CLI style headers
|
||||||
|
if isIDCAuth(auth) {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||||
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||||
|
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||||
|
} else {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||||
|
}
|
||||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||||
|
|
||||||
|
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
var attrs map[string]string
|
var attrs map[string]string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
@@ -677,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
|
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
// Rebuild payload with new profile ARN if changed
|
// Rebuild payload with new profile ARN if changed
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
@@ -735,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
log.Infof("kiro: token refreshed for 403, retrying stream request")
|
log.Infof("kiro: token refreshed for 403, retrying stream request")
|
||||||
@@ -1001,12 +1084,12 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
|||||||
// This consolidates the auth_method check that was previously done separately.
|
// This consolidates the auth_method check that was previously done separately.
|
||||||
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
||||||
// builder-id auth doesn't need profileArn
|
// builder-id and idc auth don't need profileArn
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// For non-builder-id auth (social auth), profileArn is required
|
// For non-builder-id/idc auth (social auth), profileArn is required
|
||||||
if profileArn == "" {
|
if profileArn == "" {
|
||||||
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
||||||
}
|
}
|
||||||
@@ -3010,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
var refreshToken string
|
var refreshToken string
|
||||||
var clientID, clientSecret string
|
var clientID, clientSecret string
|
||||||
var authMethod string
|
var authMethod string
|
||||||
|
var region, startURL string
|
||||||
|
|
||||||
if auth.Metadata != nil {
|
if auth.Metadata != nil {
|
||||||
if rt, ok := auth.Metadata["refresh_token"].(string); ok {
|
if rt, ok := auth.Metadata["refresh_token"].(string); ok {
|
||||||
@@ -3024,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
if am, ok := auth.Metadata["auth_method"].(string); ok {
|
if am, ok := auth.Metadata["auth_method"].(string); ok {
|
||||||
authMethod = am
|
authMethod = am
|
||||||
}
|
}
|
||||||
|
if r, ok := auth.Metadata["region"].(string); ok {
|
||||||
|
region = r
|
||||||
|
}
|
||||||
|
if su, ok := auth.Metadata["start_url"].(string); ok {
|
||||||
|
startURL = su
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
@@ -3033,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
var tokenData *kiroauth.KiroTokenData
|
var tokenData *kiroauth.KiroTokenData
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
|
||||||
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
|
||||||
|
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||||
|
switch {
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||||
|
// IDC refresh with region-specific endpoint
|
||||||
|
log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region)
|
||||||
|
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||||
|
// Builder ID refresh with default endpoint
|
||||||
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
|
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
|
||||||
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
|
|
||||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
} else {
|
default:
|
||||||
|
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
|
||||||
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
|
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
|
||||||
oauth := kiroauth.NewKiroOAuth(e.cfg)
|
oauth := kiroauth.NewKiroOAuth(e.cfg)
|
||||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||||
@@ -3094,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// persistRefreshedAuth persists a refreshed auth record to disk.
|
||||||
|
// This ensures token refreshes from inline retry are saved to the auth file.
|
||||||
|
func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return fmt.Errorf("kiro executor: cannot persist nil auth or metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the file path from auth attributes or filename
|
||||||
|
var authPath string
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
|
||||||
|
authPath = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if authPath == "" {
|
||||||
|
fileName := strings.TrimSpace(auth.FileName)
|
||||||
|
if fileName == "" {
|
||||||
|
return fmt.Errorf("kiro executor: auth has no file path or filename")
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(fileName) {
|
||||||
|
authPath = fileName
|
||||||
|
} else if e.cfg != nil && e.cfg.AuthDir != "" {
|
||||||
|
authPath = filepath.Join(e.cfg.AuthDir, fileName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("kiro executor: cannot determine auth file path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal metadata to JSON
|
||||||
|
raw, err := json.Marshal(auth.Metadata)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("kiro executor: marshal metadata failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to temp file first, then rename (atomic write)
|
||||||
|
tmp := authPath + ".tmp"
|
||||||
|
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("kiro executor: write temp auth file failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Rename(tmp, authPath); err != nil {
|
||||||
|
return fmt.Errorf("kiro executor: rename auth file failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro executor: persisted refreshed auth to %s", authPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// isTokenExpired checks if a JWT access token has expired.
|
// isTokenExpired checks if a JWT access token has expired.
|
||||||
// Returns true if the token is expired or cannot be parsed.
|
// Returns true if the token is expired or cannot be parsed.
|
||||||
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||||
|
|||||||
@@ -86,6 +86,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
hasSystemInstruction = true
|
hasSystemInstruction = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if systemResult.Type == gjson.String {
|
||||||
|
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
|
||||||
|
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||||
|
hasSystemInstruction = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// contents
|
// contents
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type Params struct {
|
|||||||
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
|
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
|
||||||
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
|
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
|
||||||
TotalTokenCount int64 // Cached total token count from usage metadata
|
TotalTokenCount int64 // Cached total token count from usage metadata
|
||||||
|
CachedTokenCount int64 // Cached content token count (indicates prompt caching)
|
||||||
HasSentFinalEvents bool // Indicates if final content/message events have been sent
|
HasSentFinalEvents bool // Indicates if final content/message events have been sent
|
||||||
HasToolUse bool // Indicates if tool use was observed in the stream
|
HasToolUse bool // Indicates if tool use was observed in the stream
|
||||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||||
@@ -274,6 +275,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
|
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
|
||||||
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
|
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
|
||||||
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
|
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
|
||||||
|
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
|
||||||
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
|
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
|
||||||
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
|
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
|
||||||
if params.CandidatesTokenCount < 0 {
|
if params.CandidatesTokenCount < 0 {
|
||||||
@@ -322,6 +324,14 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
|||||||
*output = *output + "event: message_delta\n"
|
*output = *output + "event: message_delta\n"
|
||||||
*output = *output + "data: "
|
*output = *output + "data: "
|
||||||
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
||||||
|
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||||
|
if params.CachedTokenCount > 0 {
|
||||||
|
var err error
|
||||||
|
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
*output = *output + delta + "\n\n\n"
|
*output = *output + delta + "\n\n\n"
|
||||||
|
|
||||||
params.HasSentFinalEvents = true
|
params.HasSentFinalEvents = true
|
||||||
@@ -361,6 +371,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
|
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
|
||||||
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
|
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
|
||||||
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
|
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
|
||||||
|
cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int()
|
||||||
outputTokens := candidateTokens + thoughtTokens
|
outputTokens := candidateTokens + thoughtTokens
|
||||||
if outputTokens == 0 && totalTokens > 0 {
|
if outputTokens == 0 && totalTokens > 0 {
|
||||||
outputTokens = totalTokens - promptTokens
|
outputTokens = totalTokens - promptTokens
|
||||||
@@ -374,6 +385,14 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||||
|
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
var err error
|
||||||
|
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
contentArrayInitialized := false
|
contentArrayInitialized := false
|
||||||
ensureContentArray := func() {
|
ensureContentArray := func() {
|
||||||
|
|||||||
@@ -266,7 +266,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
fargs := tc.Get("function.arguments").String()
|
fargs := tc.Get("function.arguments").String()
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
if gjson.Valid(fargs) {
|
||||||
|
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||||
|
} else {
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs))
|
||||||
|
}
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
if fid != "" {
|
if fid != "" {
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -93,10 +95,19 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
// Include cached token count if present (indicates prompt caching is working)
|
||||||
|
if cachedTokenCount > 0 {
|
||||||
|
var err error
|
||||||
|
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
|||||||
if hasSystemParts {
|
if hasSystemParts {
|
||||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
|
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
|
||||||
}
|
}
|
||||||
|
} else if systemResult.Type == gjson.String {
|
||||||
|
out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// contents
|
// contents
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||||
|
|
||||||
// Append a single tool content combining name + response per function
|
// Append a single tool content combining name + response per function
|
||||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||||
pp := 0
|
pp := 0
|
||||||
for _, fid := range fIDs {
|
for _, fid := range fIDs {
|
||||||
if name, ok := tcID2Name[fid]; ok {
|
if name, ok := tcID2Name[fid]; ok {
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
if hasSystemParts {
|
if hasSystemParts {
|
||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
|
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
|
||||||
}
|
}
|
||||||
|
} else if systemResult.Type == gjson.String {
|
||||||
|
out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// contents
|
// contents
|
||||||
|
|||||||
@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||||
|
|
||||||
// Append a single tool content combining name + response per function
|
// Append a single tool content combining name + response per function
|
||||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||||
pp := 0
|
pp := 0
|
||||||
for _, fid := range fIDs {
|
for _, fid := range fIDs {
|
||||||
if name, ok := tcID2Name[fid]; ok {
|
if name, ok := tcID2Name[fid]; ok {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -96,10 +97,19 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
// Include cached token count if present (indicates prompt caching is working)
|
||||||
|
if cachedTokenCount > 0 {
|
||||||
|
var err error
|
||||||
|
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
@@ -240,10 +250,19 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
// Include cached token count if present (indicates prompt caching is working)
|
||||||
|
if cachedTokenCount > 0 {
|
||||||
|
var err error
|
||||||
|
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
// MergeAdjacentMessages merges adjacent messages with the same role.
|
// MergeAdjacentMessages merges adjacent messages with the same role.
|
||||||
// This reduces API call complexity and improves compatibility.
|
// This reduces API call complexity and improves compatibility.
|
||||||
// Based on AIClient-2-API implementation.
|
// Based on AIClient-2-API implementation.
|
||||||
|
// NOTE: Tool messages are NOT merged because each has a unique tool_call_id that must be preserved.
|
||||||
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
||||||
if len(messages) <= 1 {
|
if len(messages) <= 1 {
|
||||||
return messages
|
return messages
|
||||||
@@ -26,6 +27,12 @@ func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
|||||||
currentRole := msg.Get("role").String()
|
currentRole := msg.Get("role").String()
|
||||||
lastRole := lastMsg.Get("role").String()
|
lastRole := lastMsg.Get("role").String()
|
||||||
|
|
||||||
|
// Don't merge tool messages - each has a unique tool_call_id
|
||||||
|
if currentRole == "tool" || lastRole == "tool" {
|
||||||
|
merged = append(merged, msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if currentRole == lastRole {
|
if currentRole == lastRole {
|
||||||
// Merge content from current message into last message
|
// Merge content from current message into last message
|
||||||
mergedContent := mergeMessageContent(lastMsg, msg)
|
mergedContent := mergeMessageContent(lastMsg, msg)
|
||||||
|
|||||||
@@ -450,24 +450,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
// Merge adjacent messages with the same role
|
// Merge adjacent messages with the same role
|
||||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||||
|
|
||||||
// Build tool_call_id to name mapping from assistant messages
|
// Track pending tool results that should be attached to the next user message
|
||||||
toolCallIDToName := make(map[string]string)
|
// This is critical for LiteLLM-translated requests where tool results appear
|
||||||
for _, msg := range messagesArray {
|
// as separate "tool" role messages between assistant and user messages
|
||||||
if msg.Get("role").String() == "assistant" {
|
var pendingToolResults []KiroToolResult
|
||||||
toolCalls := msg.Get("tool_calls")
|
|
||||||
if toolCalls.IsArray() {
|
|
||||||
for _, tc := range toolCalls.Array() {
|
|
||||||
if tc.Get("type").String() == "function" {
|
|
||||||
id := tc.Get("id").String()
|
|
||||||
name := tc.Get("function.name").String()
|
|
||||||
if id != "" && name != "" {
|
|
||||||
toolCallIDToName[id] = name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, msg := range messagesArray {
|
for i, msg := range messagesArray {
|
||||||
role := msg.Get("role").String()
|
role := msg.Get("role").String()
|
||||||
@@ -480,6 +466,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
|
|
||||||
case "user":
|
case "user":
|
||||||
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
||||||
|
// Merge any pending tool results from preceding "tool" role messages
|
||||||
|
toolResults = append(pendingToolResults, toolResults...)
|
||||||
|
pendingToolResults = nil // Reset pending tool results
|
||||||
|
|
||||||
if isLastMessage {
|
if isLastMessage {
|
||||||
currentUserMsg = &userMsg
|
currentUserMsg = &userMsg
|
||||||
currentToolResults = toolResults
|
currentToolResults = toolResults
|
||||||
@@ -505,6 +495,24 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
|
|
||||||
case "assistant":
|
case "assistant":
|
||||||
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
||||||
|
|
||||||
|
// If there are pending tool results, we need to insert a synthetic user message
|
||||||
|
// before this assistant message to maintain proper conversation structure
|
||||||
|
if len(pendingToolResults) > 0 {
|
||||||
|
syntheticUserMsg := KiroUserInputMessage{
|
||||||
|
Content: "Tool results provided.",
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
UserInputMessageContext: &KiroUserInputMessageContext{
|
||||||
|
ToolResults: pendingToolResults,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
UserInputMessage: &syntheticUserMsg,
|
||||||
|
})
|
||||||
|
pendingToolResults = nil
|
||||||
|
}
|
||||||
|
|
||||||
if isLastMessage {
|
if isLastMessage {
|
||||||
history = append(history, KiroHistoryMessage{
|
history = append(history, KiroHistoryMessage{
|
||||||
AssistantResponseMessage: &assistantMsg,
|
AssistantResponseMessage: &assistantMsg,
|
||||||
@@ -524,7 +532,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
case "tool":
|
case "tool":
|
||||||
// Tool messages in OpenAI format provide results for tool_calls
|
// Tool messages in OpenAI format provide results for tool_calls
|
||||||
// These are typically followed by user or assistant messages
|
// These are typically followed by user or assistant messages
|
||||||
// Process them and merge into the next user message's tool results
|
// Collect them as pending and attach to the next user message
|
||||||
toolCallID := msg.Get("tool_call_id").String()
|
toolCallID := msg.Get("tool_call_id").String()
|
||||||
content := msg.Get("content").String()
|
content := msg.Get("content").String()
|
||||||
|
|
||||||
@@ -534,9 +542,21 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
Content: []KiroTextContent{{Text: content}},
|
Content: []KiroTextContent{{Text: content}},
|
||||||
Status: "success",
|
Status: "success",
|
||||||
}
|
}
|
||||||
// Tool results should be included in the next user message
|
// Collect pending tool results to attach to the next user message
|
||||||
// For now, collect them and they'll be handled when we build the current message
|
pendingToolResults = append(pendingToolResults, toolResult)
|
||||||
currentToolResults = append(currentToolResults, toolResult)
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle case where tool results are at the end with no following user message
|
||||||
|
if len(pendingToolResults) > 0 {
|
||||||
|
currentToolResults = append(currentToolResults, pendingToolResults...)
|
||||||
|
// If there's no current user message, create a synthetic one for the tool results
|
||||||
|
if currentUserMsg == nil {
|
||||||
|
currentUserMsg = &KiroUserInputMessage{
|
||||||
|
Content: "Tool results provided.",
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -551,9 +571,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
|
|||||||
var toolResults []KiroToolResult
|
var toolResults []KiroToolResult
|
||||||
var images []KiroImage
|
var images []KiroImage
|
||||||
|
|
||||||
// Track seen toolCallIds to deduplicate
|
|
||||||
seenToolCallIDs := make(map[string]bool)
|
|
||||||
|
|
||||||
if content.IsArray() {
|
if content.IsArray() {
|
||||||
for _, part := range content.Array() {
|
for _, part := range content.Array() {
|
||||||
partType := part.Get("type").String()
|
partType := part.Get("type").String()
|
||||||
@@ -589,9 +606,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
|
|||||||
contentBuilder.WriteString(content.String())
|
contentBuilder.WriteString(content.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
|
|
||||||
_ = seenToolCallIDs // Used for deduplication if needed
|
|
||||||
|
|
||||||
userMsg := KiroUserInputMessage{
|
userMsg := KiroUserInputMessage{
|
||||||
Content: contentBuilder.String(),
|
Content: contentBuilder.String(),
|
||||||
ModelID: modelID,
|
ModelID: modelID,
|
||||||
|
|||||||
386
internal/translator/kiro/openai/kiro_openai_request_test.go
Normal file
386
internal/translator/kiro/openai/kiro_openai_request_test.go
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestToolResultsAttachedToCurrentMessage verifies that tool results from "tool" role messages
|
||||||
|
// are properly attached to the current user message (the last message in the conversation).
|
||||||
|
// This is critical for LiteLLM-translated requests where tool results appear as separate messages.
|
||||||
|
func TestToolResultsAttachedToCurrentMessage(t *testing.T) {
|
||||||
|
// OpenAI format request simulating LiteLLM's translation from Anthropic format
|
||||||
|
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user
|
||||||
|
// The last user message should have the tool results attached
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello, can you read a file for me?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read that file for you.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_abc123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_abc123",
|
||||||
|
"content": "File contents: Hello World!"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "What did the file say?"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The last user message becomes currentMessage
|
||||||
|
// History should have: user (first), assistant (with tool_calls)
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
if len(payload.ConversationState.History) != 2 {
|
||||||
|
t.Errorf("Expected 2 history entries (user + assistant), got %d", len(payload.ConversationState.History))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool results should be attached to currentMessage (the last user message)
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx == nil {
|
||||||
|
t.Fatal("Expected currentMessage to have UserInputMessageContext with tool results")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ctx.ToolResults) != 1 {
|
||||||
|
t.Fatalf("Expected 1 tool result in currentMessage, got %d", len(ctx.ToolResults))
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := ctx.ToolResults[0]
|
||||||
|
if tr.ToolUseID != "call_abc123" {
|
||||||
|
t.Errorf("Expected toolUseId 'call_abc123', got '%s'", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
if len(tr.Content) == 0 || tr.Content[0].Text != "File contents: Hello World!" {
|
||||||
|
t.Errorf("Tool result content mismatch, got: %+v", tr.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsInHistoryUserMessage verifies that when there are multiple user messages
|
||||||
|
// after tool results, the tool results are attached to the correct user message in history.
|
||||||
|
func TestToolResultsInHistoryUserMessage(t *testing.T) {
|
||||||
|
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user -> assistant -> user
|
||||||
|
// The first user after tool should have tool results in history
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read the file.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": "File result"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Thanks for the file"},
|
||||||
|
{"role": "assistant", "content": "You're welcome"},
|
||||||
|
{"role": "user", "content": "Bye"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// History should have: user, assistant, user (with tool results), assistant
|
||||||
|
// CurrentMessage should be: last user "Bye"
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
|
||||||
|
// Find the user message in history with tool results
|
||||||
|
foundToolResults := false
|
||||||
|
for i, h := range payload.ConversationState.History {
|
||||||
|
if h.UserInputMessage != nil {
|
||||||
|
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
|
||||||
|
if h.UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
if len(h.UserInputMessage.UserInputMessageContext.ToolResults) > 0 {
|
||||||
|
foundToolResults = true
|
||||||
|
t.Logf(" Found %d tool results", len(h.UserInputMessage.UserInputMessageContext.ToolResults))
|
||||||
|
tr := h.UserInputMessage.UserInputMessageContext.ToolResults[0]
|
||||||
|
if tr.ToolUseID != "call_1" {
|
||||||
|
t.Errorf("Expected toolUseId 'call_1', got '%s'", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.AssistantResponseMessage != nil {
|
||||||
|
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundToolResults {
|
||||||
|
t.Error("Tool results were not attached to any user message in history")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsWithMultipleToolCalls verifies handling of multiple tool calls
|
||||||
|
func TestToolResultsWithMultipleToolCalls(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Read two files for me"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read both files.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/file1.txt\"}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_2",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/file2.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": "Content of file 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_2",
|
||||||
|
"content": "Content of file 2"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "What do they say?"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
t.Logf("CurrentMessage content: %q", payload.ConversationState.CurrentMessage.UserInputMessage.Content)
|
||||||
|
|
||||||
|
// Check if there are any tool results anywhere
|
||||||
|
var totalToolResults int
|
||||||
|
for i, h := range payload.ConversationState.History {
|
||||||
|
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
|
||||||
|
t.Logf("History[%d] user message has %d tool results", i, count)
|
||||||
|
totalToolResults += count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx != nil {
|
||||||
|
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
|
||||||
|
totalToolResults += len(ctx.ToolResults)
|
||||||
|
} else {
|
||||||
|
t.Logf("CurrentMessage has no UserInputMessageContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalToolResults != 2 {
|
||||||
|
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsAtEndOfConversation verifies tool results are handled when
|
||||||
|
// the conversation ends with tool results (no following user message)
|
||||||
|
func TestToolResultsAtEndOfConversation(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Read a file"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Reading the file.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_end",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_end",
|
||||||
|
"content": "File contents here"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the last message is a tool result, a synthetic user message is created
|
||||||
|
// and tool results should be attached to it
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx == nil || len(ctx.ToolResults) == 0 {
|
||||||
|
t.Error("Expected tool results to be attached to current message when conversation ends with tool result")
|
||||||
|
} else {
|
||||||
|
if ctx.ToolResults[0].ToolUseID != "call_end" {
|
||||||
|
t.Errorf("Expected toolUseId 'call_end', got '%s'", ctx.ToolResults[0].ToolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsFollowedByAssistant verifies handling when tool results are followed
|
||||||
|
// by an assistant message (no intermediate user message).
|
||||||
|
// This is the pattern from LiteLLM translation of Anthropic format where:
|
||||||
|
// user message has ONLY tool_result blocks -> LiteLLM creates tool messages
|
||||||
|
// then the next message is assistant
|
||||||
|
func TestToolResultsFollowedByAssistant(t *testing.T) {
|
||||||
|
// Sequence: user -> assistant (with tool_calls) -> tool -> tool -> assistant -> user
|
||||||
|
// This simulates LiteLLM's translation of:
|
||||||
|
// user: "Read files"
|
||||||
|
// assistant: [tool_use, tool_use]
|
||||||
|
// user: [tool_result, tool_result] <- becomes multiple "tool" role messages
|
||||||
|
// assistant: "I've read them"
|
||||||
|
// user: "What did they say?"
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Read two files for me"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read both files.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/a.txt\"}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_2",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/b.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": "Contents of file A"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_2",
|
||||||
|
"content": "Contents of file B"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I've read both files."
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "What did they say?"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
|
||||||
|
// Tool results should be attached to a synthetic user message or the history should be valid
|
||||||
|
var totalToolResults int
|
||||||
|
for i, h := range payload.ConversationState.History {
|
||||||
|
if h.UserInputMessage != nil {
|
||||||
|
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
|
||||||
|
if h.UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
|
||||||
|
t.Logf(" Has %d tool results", count)
|
||||||
|
totalToolResults += count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.AssistantResponseMessage != nil {
|
||||||
|
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx != nil {
|
||||||
|
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
|
||||||
|
totalToolResults += len(ctx.ToolResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalToolResults != 2 {
|
||||||
|
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAssistantEndsConversation verifies handling when assistant is the last message
|
||||||
|
func TestAssistantEndsConversation(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hi there!"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When assistant is last, a "Continue" user message should be created
|
||||||
|
if payload.ConversationState.CurrentMessage.UserInputMessage.Content == "" {
|
||||||
|
t.Error("Expected a 'Continue' message to be created when assistant is last")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
@@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
// - c: The Gin context for the request.
|
// - c: The Gin context for the request.
|
||||||
// - rawJSON: The raw JSON request body.
|
// - rawJSON: The raw JSON request body.
|
||||||
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
|
||||||
// These headers are essential for maintaining a persistent connection
|
|
||||||
// and enabling real-time streaming of chat completions
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
@@ -213,58 +204,82 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
setSSEHeaders := func() {
|
||||||
return
|
c.Header("Content-Type", "text/event-stream")
|
||||||
}
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
// Peek at the first chunk to determine success or failure before setting headers
|
||||||
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
|
|
||||||
// This guarantees clients see incremental output even for small responses.
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.Request.Context().Done():
|
case <-c.Request.Context().Done():
|
||||||
cancel(c.Request.Context().Err())
|
cliCancel(c.Request.Context().Err())
|
||||||
return
|
return
|
||||||
|
case errMsg, ok := <-errChan:
|
||||||
case chunk, ok := <-data:
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
// Err channel closed cleanly; wait for data channel.
|
||||||
|
errChan = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Upstream failed immediately. Return proper error status and JSON.
|
||||||
|
h.WriteErrorResponse(c, errMsg)
|
||||||
|
if errMsg != nil {
|
||||||
|
cliCancel(errMsg.Error)
|
||||||
|
} else {
|
||||||
|
cliCancel(nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case chunk, ok := <-dataChan:
|
||||||
|
if !ok {
|
||||||
|
// Stream closed without data? Send DONE or just headers.
|
||||||
|
setSSEHeaders()
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cancel(nil)
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Success! Set headers now.
|
||||||
|
setSSEHeaders()
|
||||||
|
|
||||||
|
// Write the first chunk
|
||||||
if len(chunk) > 0 {
|
if len(chunk) > 0 {
|
||||||
_, _ = c.Writer.Write(chunk)
|
_, _ = c.Writer.Write(chunk)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
case errMsg, ok := <-errs:
|
// Continue streaming the rest
|
||||||
if !ok {
|
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||||
continue
|
|
||||||
}
|
|
||||||
if errMsg != nil {
|
|
||||||
status := http.StatusInternalServerError
|
|
||||||
if errMsg.StatusCode > 0 {
|
|
||||||
status = errMsg.StatusCode
|
|
||||||
}
|
|
||||||
c.Status(status)
|
|
||||||
|
|
||||||
// An error occurred: emit as a proper SSE error event
|
|
||||||
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
|
||||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
var execErr error
|
|
||||||
if errMsg != nil {
|
|
||||||
execErr = errMsg.Error
|
|
||||||
}
|
|
||||||
cancel(execErr)
|
|
||||||
return
|
return
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||||
|
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||||
|
WriteChunk: func(chunk []byte) {
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
},
|
||||||
|
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||||
|
if errMsg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
status := http.StatusInternalServerError
|
||||||
|
if errMsg.StatusCode > 0 {
|
||||||
|
status = errMsg.StatusCode
|
||||||
|
}
|
||||||
|
c.Status(status)
|
||||||
|
|
||||||
|
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
||||||
|
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type claudeErrorDetail struct {
|
type claudeErrorDetail struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|||||||
@@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||||
for {
|
var keepAliveInterval *time.Duration
|
||||||
select {
|
if alt != "" {
|
||||||
case <-c.Request.Context().Done():
|
disabled := time.Duration(0)
|
||||||
cancel(c.Request.Context().Err())
|
keepAliveInterval = &disabled
|
||||||
return
|
}
|
||||||
case chunk, ok := <-data:
|
|
||||||
if !ok {
|
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||||
cancel(nil)
|
KeepAliveInterval: keepAliveInterval,
|
||||||
return
|
WriteChunk: func(chunk []byte) {
|
||||||
}
|
|
||||||
if alt == "" {
|
if alt == "" {
|
||||||
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
|
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.HasPrefix(chunk, []byte("data:")) {
|
if !bytes.HasPrefix(chunk, []byte("data:")) {
|
||||||
@@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus
|
|||||||
} else {
|
} else {
|
||||||
_, _ = c.Writer.Write(chunk)
|
_, _ = c.Writer.Write(chunk)
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
},
|
||||||
case errMsg, ok := <-errs:
|
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||||
if !ok {
|
if errMsg == nil {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
if errMsg != nil {
|
status := http.StatusInternalServerError
|
||||||
h.WriteErrorResponse(c, errMsg)
|
if errMsg.StatusCode > 0 {
|
||||||
flusher.Flush()
|
status = errMsg.StatusCode
|
||||||
}
|
}
|
||||||
var execErr error
|
errText := http.StatusText(status)
|
||||||
if errMsg != nil {
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
execErr = errMsg.Error
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
cancel(execErr)
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
return
|
if alt == "" {
|
||||||
case <-time.After(500 * time.Millisecond):
|
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
|
||||||
}
|
} else {
|
||||||
}
|
_, _ = c.Writer.Write(body)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
|||||||
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
|
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
|
||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
|
|
||||||
if alt == "" {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -247,8 +240,65 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
|
|
||||||
return
|
setSSEHeaders := func() {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek at the first chunk
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cliCancel(c.Request.Context().Err())
|
||||||
|
return
|
||||||
|
case errMsg, ok := <-errChan:
|
||||||
|
if !ok {
|
||||||
|
// Err channel closed cleanly; wait for data channel.
|
||||||
|
errChan = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Upstream failed immediately. Return proper error status and JSON.
|
||||||
|
h.WriteErrorResponse(c, errMsg)
|
||||||
|
if errMsg != nil {
|
||||||
|
cliCancel(errMsg.Error)
|
||||||
|
} else {
|
||||||
|
cliCancel(nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case chunk, ok := <-dataChan:
|
||||||
|
if !ok {
|
||||||
|
// Closed without data
|
||||||
|
if alt == "" {
|
||||||
|
setSSEHeaders()
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success! Set headers.
|
||||||
|
if alt == "" {
|
||||||
|
setSSEHeaders()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write first chunk
|
||||||
|
if alt == "" {
|
||||||
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||||
|
} else {
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Continue
|
||||||
|
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCountTokens handles token counting requests for Gemini models.
|
// handleCountTokens handles token counting requests for Gemini models.
|
||||||
@@ -297,16 +347,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||||
for {
|
var keepAliveInterval *time.Duration
|
||||||
select {
|
if alt != "" {
|
||||||
case <-c.Request.Context().Done():
|
disabled := time.Duration(0)
|
||||||
cancel(c.Request.Context().Err())
|
keepAliveInterval = &disabled
|
||||||
return
|
}
|
||||||
case chunk, ok := <-data:
|
|
||||||
if !ok {
|
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||||
cancel(nil)
|
KeepAliveInterval: keepAliveInterval,
|
||||||
return
|
WriteChunk: func(chunk []byte) {
|
||||||
}
|
|
||||||
if alt == "" {
|
if alt == "" {
|
||||||
_, _ = c.Writer.Write([]byte("data: "))
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
_, _ = c.Writer.Write(chunk)
|
_, _ = c.Writer.Write(chunk)
|
||||||
@@ -314,22 +363,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus
|
|||||||
} else {
|
} else {
|
||||||
_, _ = c.Writer.Write(chunk)
|
_, _ = c.Writer.Write(chunk)
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
},
|
||||||
case errMsg, ok := <-errs:
|
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||||
if !ok {
|
if errMsg == nil {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
if errMsg != nil {
|
status := http.StatusInternalServerError
|
||||||
h.WriteErrorResponse(c, errMsg)
|
if errMsg.StatusCode > 0 {
|
||||||
flusher.Flush()
|
status = errMsg.StatusCode
|
||||||
}
|
}
|
||||||
var execErr error
|
errText := http.StatusText(status)
|
||||||
if errMsg != nil {
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
execErr = errMsg.Error
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
cancel(execErr)
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
return
|
if alt == "" {
|
||||||
case <-time.After(500 * time.Millisecond):
|
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
|
||||||
}
|
} else {
|
||||||
}
|
_, _ = c.Writer.Write(body)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,9 +9,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -40,6 +43,117 @@ type ErrorDetail struct {
|
|||||||
Code string `json:"code,omitempty"`
|
Code string `json:"code,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
	// idempotencyKeyMetadataKey is the execution-metadata map key under which
	// the request's idempotency key is stored.
	idempotencyKeyMetadataKey = "idempotency_key"

	// defaultStreamingKeepAliveSeconds is used when no keep-alive interval is
	// configured; 0 means keep-alives are disabled.
	defaultStreamingKeepAliveSeconds = 0

	// defaultStreamingBootstrapRetries is used when no bootstrap retry count is
	// configured; 0 means no retries.
	defaultStreamingBootstrapRetries = 0
)
|
||||||
|
|
||||||
|
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||||
|
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||||
|
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||||
|
if status <= 0 {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(errText) == "" {
|
||||||
|
errText = http.StatusText(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(errText)
|
||||||
|
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||||
|
return []byte(trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
errType := "invalid_request_error"
|
||||||
|
var code string
|
||||||
|
switch status {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
errType = "authentication_error"
|
||||||
|
code = "invalid_api_key"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
errType = "permission_error"
|
||||||
|
code = "insufficient_quota"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
errType = "rate_limit_error"
|
||||||
|
code = "rate_limit_exceeded"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
errType = "invalid_request_error"
|
||||||
|
code = "model_not_found"
|
||||||
|
default:
|
||||||
|
if status >= http.StatusInternalServerError {
|
||||||
|
errType = "server_error"
|
||||||
|
code = "internal_server_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := json.Marshal(ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: errText,
|
||||||
|
Type: errType,
|
||||||
|
Code: code,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
|
||||||
|
// Returning 0 disables keep-alives (default when unset).
|
||||||
|
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||||
|
seconds := defaultStreamingKeepAliveSeconds
|
||||||
|
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
|
||||||
|
seconds = *cfg.Streaming.KeepAliveSeconds
|
||||||
|
}
|
||||||
|
if seconds <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||||
|
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||||
|
retries := defaultStreamingBootstrapRetries
|
||||||
|
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
|
||||||
|
retries = *cfg.Streaming.BootstrapRetries
|
||||||
|
}
|
||||||
|
if retries < 0 {
|
||||||
|
retries = 0
|
||||||
|
}
|
||||||
|
return retries
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||||
|
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||||
|
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||||
|
key := ""
|
||||||
|
if ctx != nil {
|
||||||
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
|
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
key = uuid.NewString()
|
||||||
|
}
|
||||||
|
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeMetadata returns a new map containing every entry of base and overlay.
// On duplicate keys the overlay value wins. Neither input is modified. When
// both inputs are empty (or nil) the result is nil.
func mergeMetadata(base, overlay map[string]any) map[string]any {
	total := len(base) + len(overlay)
	if total == 0 {
		return nil
	}

	merged := make(map[string]any, total)
	// base is copied first so that overlay entries overwrite it.
	for _, src := range []map[string]any{base, overlay} {
		for key, value := range src {
			merged[key] = value
		}
	}
	return merged
}
|
||||||
|
|
||||||
// BaseAPIHandler contains the handlers for API endpoints.
|
// BaseAPIHandler contains the handlers for API endpoints.
|
||||||
// It holds a pool of clients to interact with the backend service and manages
|
// It holds a pool of clients to interact with the backend service and manages
|
||||||
// load balancing, client selection, and configuration.
|
// load balancing, client selection, and configuration.
|
||||||
@@ -104,13 +218,39 @@ func (h *BaseAPIHandler) GetAlt(c *gin.Context) string {
|
|||||||
// Parameters:
|
// Parameters:
|
||||||
// - handler: The API handler associated with the request.
|
// - handler: The API handler associated with the request.
|
||||||
// - c: The Gin context of the current request.
|
// - c: The Gin context of the current request.
|
||||||
// - ctx: The parent context.
|
// - ctx: The parent context (caller values/deadlines are preserved; request context adds cancellation and request ID).
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - context.Context: The new context with cancellation and embedded values.
|
// - context.Context: The new context with cancellation and embedded values.
|
||||||
// - APIHandlerCancelFunc: A function to cancel the context and log the response.
|
// - APIHandlerCancelFunc: A function to cancel the context and log the response.
|
||||||
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
|
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
|
||||||
newCtx, cancel := context.WithCancel(ctx)
|
parentCtx := ctx
|
||||||
|
if parentCtx == nil {
|
||||||
|
parentCtx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestCtx context.Context
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
requestCtx = c.Request.Context()
|
||||||
|
}
|
||||||
|
|
||||||
|
if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
|
||||||
|
if requestID := logging.GetRequestID(requestCtx); requestID != "" {
|
||||||
|
parentCtx = logging.WithRequestID(parentCtx, requestID)
|
||||||
|
} else if requestID := logging.GetGinRequestID(c); requestID != "" {
|
||||||
|
parentCtx = logging.WithRequestID(parentCtx, requestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newCtx, cancel := context.WithCancel(parentCtx)
|
||||||
|
if requestCtx != nil && requestCtx != parentCtx {
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-requestCtx.Done():
|
||||||
|
cancel()
|
||||||
|
case <-newCtx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
newCtx = context.WithValue(newCtx, "gin", c)
|
newCtx = context.WithValue(newCtx, "gin", c)
|
||||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||||
return newCtx, func(params ...interface{}) {
|
return newCtx, func(params ...interface{}) {
|
||||||
@@ -183,6 +323,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
return nil, errMsg
|
return nil, errMsg
|
||||||
}
|
}
|
||||||
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
Payload: cloneBytes(rawJSON),
|
Payload: cloneBytes(rawJSON),
|
||||||
@@ -196,9 +337,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
OriginalRequest: cloneBytes(rawJSON),
|
OriginalRequest: cloneBytes(rawJSON),
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||||
opts.Metadata = cloned
|
|
||||||
}
|
|
||||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
@@ -225,6 +364,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
return nil, errMsg
|
return nil, errMsg
|
||||||
}
|
}
|
||||||
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
Payload: cloneBytes(rawJSON),
|
Payload: cloneBytes(rawJSON),
|
||||||
@@ -238,9 +378,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
OriginalRequest: cloneBytes(rawJSON),
|
OriginalRequest: cloneBytes(rawJSON),
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||||
opts.Metadata = cloned
|
|
||||||
}
|
|
||||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
@@ -270,6 +408,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, errChan
|
return nil, errChan
|
||||||
}
|
}
|
||||||
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
Payload: cloneBytes(rawJSON),
|
Payload: cloneBytes(rawJSON),
|
||||||
@@ -283,9 +422,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
OriginalRequest: cloneBytes(rawJSON),
|
OriginalRequest: cloneBytes(rawJSON),
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||||
opts.Metadata = cloned
|
|
||||||
}
|
|
||||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
@@ -310,31 +447,94 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(dataChan)
|
defer close(dataChan)
|
||||||
defer close(errChan)
|
defer close(errChan)
|
||||||
for chunk := range chunks {
|
sentPayload := false
|
||||||
if chunk.Err != nil {
|
bootstrapRetries := 0
|
||||||
status := http.StatusInternalServerError
|
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
|
||||||
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
|
|
||||||
if code := se.StatusCode(); code > 0 {
|
bootstrapEligible := func(err error) bool {
|
||||||
status = code
|
status := statusFromError(err)
|
||||||
}
|
if status == 0 {
|
||||||
}
|
return true
|
||||||
var addon http.Header
|
|
||||||
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
|
|
||||||
if hdr := he.Headers(); hdr != nil {
|
|
||||||
addon = hdr.Clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if len(chunk.Payload) > 0 {
|
switch status {
|
||||||
dataChan <- cloneBytes(chunk.Payload)
|
case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
|
||||||
|
http.StatusRequestTimeout, http.StatusTooManyRequests:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return status >= http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outer:
|
||||||
|
for {
|
||||||
|
for {
|
||||||
|
var chunk coreexecutor.StreamChunk
|
||||||
|
var ok bool
|
||||||
|
if ctx != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case chunk, ok = <-chunks:
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chunk, ok = <-chunks
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if chunk.Err != nil {
|
||||||
|
streamErr := chunk.Err
|
||||||
|
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
|
||||||
|
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
|
||||||
|
if !sentPayload {
|
||||||
|
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||||
|
bootstrapRetries++
|
||||||
|
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
|
if retryErr == nil {
|
||||||
|
chunks = retryChunks
|
||||||
|
continue outer
|
||||||
|
}
|
||||||
|
streamErr = retryErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status := http.StatusInternalServerError
|
||||||
|
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
|
||||||
|
if code := se.StatusCode(); code > 0 {
|
||||||
|
status = code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var addon http.Header
|
||||||
|
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
|
||||||
|
if hdr := he.Headers(); hdr != nil {
|
||||||
|
addon = hdr.Clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(chunk.Payload) > 0 {
|
||||||
|
sentPayload = true
|
||||||
|
dataChan <- cloneBytes(chunk.Payload)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return dataChan, errChan
|
return dataChan, errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func statusFromError(err error) int {
|
||||||
|
if err == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||||
|
if code := se.StatusCode(); code > 0 {
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||||
// Resolve "auto" model to an actual available model first
|
// Resolve "auto" model to an actual available model first
|
||||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||||
@@ -418,38 +618,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefer preserving upstream JSON error bodies when possible.
|
body := BuildErrorResponseBody(status, errText)
|
||||||
buildJSONBody := func() []byte {
|
|
||||||
trimmed := strings.TrimSpace(errText)
|
|
||||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
|
||||||
return []byte(trimmed)
|
|
||||||
}
|
|
||||||
errType := "invalid_request_error"
|
|
||||||
switch status {
|
|
||||||
case http.StatusUnauthorized:
|
|
||||||
errType = "authentication_error"
|
|
||||||
case http.StatusForbidden:
|
|
||||||
errType = "permission_error"
|
|
||||||
case http.StatusTooManyRequests:
|
|
||||||
errType = "rate_limit_error"
|
|
||||||
default:
|
|
||||||
if status >= http.StatusInternalServerError {
|
|
||||||
errType = "server_error"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
payload, err := json.Marshal(ErrorResponse{
|
|
||||||
Error: ErrorDetail{
|
|
||||||
Message: errText,
|
|
||||||
Type: errType,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
|
|
||||||
}
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
body := buildJSONBody()
|
|
||||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||||
|
|
||||||
if !c.Writer.Written() {
|
if !c.Writer.Written() {
|
||||||
|
|||||||
125
sdk/api/handlers/handlers_stream_bootstrap_test.go
Normal file
125
sdk/api/handlers/handlers_stream_bootstrap_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type failOnceStreamExecutor struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||||
|
e.mu.Lock()
|
||||||
|
e.calls++
|
||||||
|
call := e.calls
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
if call == 1 {
|
||||||
|
ch <- coreexecutor.StreamChunk{
|
||||||
|
Err: &coreauth.Error{
|
||||||
|
Code: "unauthorized",
|
||||||
|
Message: "unauthorized",
|
||||||
|
Retryable: false,
|
||||||
|
HTTPStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||||
|
close(ch)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) Calls() int {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
return e.calls
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||||
|
executor := &failOnceStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
bootstrapRetries := 1
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: &bootstrapRetries,
|
||||||
|
},
|
||||||
|
}, manager, nil)
|
||||||
|
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(got) != "ok" {
|
||||||
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
|
}
|
||||||
|
if executor.Calls() != 2 {
|
||||||
|
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"sync"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
@@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
|||||||
// - c: The Gin context containing the HTTP request and response
|
// - c: The Gin context containing the HTTP request and response
|
||||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||||
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -463,7 +458,55 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||||
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
|
||||||
|
setSSEHeaders := func() {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek at the first chunk to determine success or failure before setting headers
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cliCancel(c.Request.Context().Err())
|
||||||
|
return
|
||||||
|
case errMsg, ok := <-errChan:
|
||||||
|
if !ok {
|
||||||
|
// Err channel closed cleanly; wait for data channel.
|
||||||
|
errChan = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Upstream failed immediately. Return proper error status and JSON.
|
||||||
|
h.WriteErrorResponse(c, errMsg)
|
||||||
|
if errMsg != nil {
|
||||||
|
cliCancel(errMsg.Error)
|
||||||
|
} else {
|
||||||
|
cliCancel(nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case chunk, ok := <-dataChan:
|
||||||
|
if !ok {
|
||||||
|
// Stream closed without data? Send DONE or just headers.
|
||||||
|
setSSEHeaders()
|
||||||
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success! Commit to streaming headers.
|
||||||
|
setSSEHeaders()
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Continue streaming the rest
|
||||||
|
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||||
@@ -500,11 +543,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
|||||||
// - c: The Gin context containing the HTTP request and response
|
// - c: The Gin context containing the HTTP request and response
|
||||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
|
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
|
||||||
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
|
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -524,71 +562,109 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||||
|
|
||||||
|
setSSEHeaders := func() {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek at the first chunk
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.Request.Context().Done():
|
case <-c.Request.Context().Done():
|
||||||
cliCancel(c.Request.Context().Err())
|
cliCancel(c.Request.Context().Err())
|
||||||
return
|
return
|
||||||
case chunk, isOk := <-dataChan:
|
case errMsg, ok := <-errChan:
|
||||||
if !isOk {
|
if !ok {
|
||||||
|
// Err channel closed cleanly; wait for data channel.
|
||||||
|
errChan = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.WriteErrorResponse(c, errMsg)
|
||||||
|
if errMsg != nil {
|
||||||
|
cliCancel(errMsg.Error)
|
||||||
|
} else {
|
||||||
|
cliCancel(nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case chunk, ok := <-dataChan:
|
||||||
|
if !ok {
|
||||||
|
setSSEHeaders()
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel()
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Success! Set headers.
|
||||||
|
setSSEHeaders()
|
||||||
|
|
||||||
|
// Write the first chunk
|
||||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||||
if converted != nil {
|
if converted != nil {
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
case errMsg, isOk := <-errChan:
|
|
||||||
if !isOk {
|
done := make(chan struct{})
|
||||||
continue
|
var doneOnce sync.Once
|
||||||
}
|
stop := func() { doneOnce.Do(func() { close(done) }) }
|
||||||
if errMsg != nil {
|
|
||||||
h.WriteErrorResponse(c, errMsg)
|
convertedChan := make(chan []byte)
|
||||||
flusher.Flush()
|
go func() {
|
||||||
}
|
defer close(convertedChan)
|
||||||
var execErr error
|
for {
|
||||||
if errMsg != nil {
|
select {
|
||||||
execErr = errMsg.Error
|
case <-done:
|
||||||
}
|
return
|
||||||
cliCancel(execErr)
|
case chunk, ok := <-dataChan:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||||
|
if converted == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case convertedChan <- converted:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
h.handleStreamResult(c, flusher, func(err error) {
|
||||||
|
stop()
|
||||||
|
cliCancel(err)
|
||||||
|
}, convertedChan, errChan)
|
||||||
return
|
return
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||||
for {
|
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||||
select {
|
WriteChunk: func(chunk []byte) {
|
||||||
case <-c.Request.Context().Done():
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||||
cancel(c.Request.Context().Err())
|
},
|
||||||
return
|
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||||
case chunk, ok := <-data:
|
if errMsg == nil {
|
||||||
if !ok {
|
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
cancel(nil)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
status := http.StatusInternalServerError
|
||||||
flusher.Flush()
|
if errMsg.StatusCode > 0 {
|
||||||
case errMsg, ok := <-errs:
|
status = errMsg.StatusCode
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if errMsg != nil {
|
errText := http.StatusText(status)
|
||||||
h.WriteErrorResponse(c, errMsg)
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
flusher.Flush()
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
var execErr error
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
if errMsg != nil {
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
|
||||||
execErr = errMsg.Error
|
},
|
||||||
}
|
WriteDone: func() {
|
||||||
cancel(execErr)
|
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||||
return
|
},
|
||||||
case <-time.After(500 * time.Millisecond):
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
@@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
|||||||
// - c: The Gin context containing the HTTP request and response
|
// - c: The Gin context containing the HTTP request and response
|
||||||
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
|
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
|
||||||
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -149,46 +143,88 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
setSSEHeaders := func() {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek at the first chunk
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.Request.Context().Done():
|
case <-c.Request.Context().Done():
|
||||||
cancel(c.Request.Context().Err())
|
cliCancel(c.Request.Context().Err())
|
||||||
return
|
return
|
||||||
case chunk, ok := <-data:
|
case errMsg, ok := <-errChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
|
// Err channel closed cleanly; wait for data channel.
|
||||||
|
errChan = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Upstream failed immediately. Return proper error status and JSON.
|
||||||
|
h.WriteErrorResponse(c, errMsg)
|
||||||
|
if errMsg != nil {
|
||||||
|
cliCancel(errMsg.Error)
|
||||||
|
} else {
|
||||||
|
cliCancel(nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case chunk, ok := <-dataChan:
|
||||||
|
if !ok {
|
||||||
|
// Stream closed without data? Send headers and done.
|
||||||
|
setSSEHeaders()
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cancel(nil)
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Success! Set headers.
|
||||||
|
setSSEHeaders()
|
||||||
|
|
||||||
|
// Write first chunk logic (matching forwardResponsesStream)
|
||||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
}
|
}
|
||||||
_, _ = c.Writer.Write(chunk)
|
_, _ = c.Writer.Write(chunk)
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case errMsg, ok := <-errs:
|
|
||||||
if !ok {
|
// Continue
|
||||||
continue
|
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||||
}
|
|
||||||
if errMsg != nil {
|
|
||||||
h.WriteErrorResponse(c, errMsg)
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
var execErr error
|
|
||||||
if errMsg != nil {
|
|
||||||
execErr = errMsg.Error
|
|
||||||
}
|
|
||||||
cancel(execErr)
|
|
||||||
return
|
return
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||||
|
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||||
|
WriteChunk: func(chunk []byte) {
|
||||||
|
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
},
|
||||||
|
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||||
|
if errMsg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
status := http.StatusInternalServerError
|
||||||
|
if errMsg.StatusCode > 0 {
|
||||||
|
status = errMsg.StatusCode
|
||||||
|
}
|
||||||
|
errText := http.StatusText(status)
|
||||||
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
|
errText = errMsg.Error.Error()
|
||||||
|
}
|
||||||
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
|
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||||
|
},
|
||||||
|
WriteDone: func() {
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
121
sdk/api/handlers/stream_forwarder.go
Normal file
121
sdk/api/handlers/stream_forwarder.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StreamForwardOptions struct {
|
||||||
|
// KeepAliveInterval overrides the configured streaming keep-alive interval.
|
||||||
|
// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
|
||||||
|
KeepAliveInterval *time.Duration
|
||||||
|
|
||||||
|
// WriteChunk writes a single data chunk to the response body. It should not flush.
|
||||||
|
WriteChunk func(chunk []byte)
|
||||||
|
|
||||||
|
// WriteTerminalError writes an error payload to the response body when streaming fails
|
||||||
|
// after headers have already been committed. It should not flush.
|
||||||
|
WriteTerminalError func(errMsg *interfaces.ErrorMessage)
|
||||||
|
|
||||||
|
// WriteDone optionally writes a terminal marker when the upstream data channel closes
|
||||||
|
// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
|
||||||
|
WriteDone func()
|
||||||
|
|
||||||
|
// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
|
||||||
|
// When nil, a standard SSE comment heartbeat is used.
|
||||||
|
WriteKeepAlive func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeChunk := opts.WriteChunk
|
||||||
|
if writeChunk == nil {
|
||||||
|
writeChunk = func([]byte) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
writeKeepAlive := opts.WriteKeepAlive
|
||||||
|
if writeKeepAlive == nil {
|
||||||
|
writeKeepAlive = func() {
|
||||||
|
_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
|
||||||
|
if opts.KeepAliveInterval != nil {
|
||||||
|
keepAliveInterval = *opts.KeepAliveInterval
|
||||||
|
}
|
||||||
|
var keepAlive *time.Ticker
|
||||||
|
var keepAliveC <-chan time.Time
|
||||||
|
if keepAliveInterval > 0 {
|
||||||
|
keepAlive = time.NewTicker(keepAliveInterval)
|
||||||
|
defer keepAlive.Stop()
|
||||||
|
keepAliveC = keepAlive.C
|
||||||
|
}
|
||||||
|
|
||||||
|
var terminalErr *interfaces.ErrorMessage
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cancel(c.Request.Context().Err())
|
||||||
|
return
|
||||||
|
case chunk, ok := <-data:
|
||||||
|
if !ok {
|
||||||
|
// Prefer surfacing a terminal error if one is pending.
|
||||||
|
if terminalErr == nil {
|
||||||
|
select {
|
||||||
|
case errMsg, ok := <-errs:
|
||||||
|
if ok && errMsg != nil {
|
||||||
|
terminalErr = errMsg
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if terminalErr != nil {
|
||||||
|
if opts.WriteTerminalError != nil {
|
||||||
|
opts.WriteTerminalError(terminalErr)
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
cancel(terminalErr.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if opts.WriteDone != nil {
|
||||||
|
opts.WriteDone()
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
cancel(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeChunk(chunk)
|
||||||
|
flusher.Flush()
|
||||||
|
case errMsg, ok := <-errs:
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errMsg != nil {
|
||||||
|
terminalErr = errMsg
|
||||||
|
if opts.WriteTerminalError != nil {
|
||||||
|
opts.WriteTerminalError(errMsg)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var execErr error
|
||||||
|
if errMsg != nil {
|
||||||
|
execErr = errMsg.Error
|
||||||
|
}
|
||||||
|
cancel(execErr)
|
||||||
|
return
|
||||||
|
case <-keepAliveC:
|
||||||
|
writeKeepAlive()
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
116
sdk/auth/kiro.go
116
sdk/auth/kiro.go
@@ -53,20 +53,8 @@ func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
|||||||
return &d
|
return &d
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login performs OAuth login for Kiro with AWS Builder ID.
|
// createAuthRecord creates an auth record from token data.
|
||||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||||
if cfg == nil {
|
|
||||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
|
||||||
|
|
||||||
// Use AWS Builder ID device code flow
|
|
||||||
tokenData, err := oauth.LoginWithBuilderID(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("login failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse expires_at
|
// Parse expires_at
|
||||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -76,34 +64,63 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
// Extract identifier for file naming
|
// Extract identifier for file naming
|
||||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
// Determine label based on auth method
|
||||||
|
label := fmt.Sprintf("kiro-%s", source)
|
||||||
|
if tokenData.AuthMethod == "idc" {
|
||||||
|
label = "kiro-idc"
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||||
|
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"client_id": tokenData.ClientID,
|
||||||
|
"client_secret": tokenData.ClientSecret,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IDC-specific fields if present
|
||||||
|
if tokenData.StartURL != "" {
|
||||||
|
metadata["start_url"] = tokenData.StartURL
|
||||||
|
}
|
||||||
|
if tokenData.Region != "" {
|
||||||
|
metadata["region"] = tokenData.Region
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes := map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": source,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IDC-specific attributes if present
|
||||||
|
if tokenData.AuthMethod == "idc" {
|
||||||
|
attributes["source"] = "aws-idc"
|
||||||
|
if tokenData.StartURL != "" {
|
||||||
|
attributes["start_url"] = tokenData.StartURL
|
||||||
|
}
|
||||||
|
if tokenData.Region != "" {
|
||||||
|
attributes["region"] = tokenData.Region
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: "kiro",
|
Provider: "kiro",
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Label: "kiro-aws",
|
Label: label,
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
Metadata: map[string]any{
|
Metadata: metadata,
|
||||||
"type": "kiro",
|
Attributes: attributes,
|
||||||
"access_token": tokenData.AccessToken,
|
|
||||||
"refresh_token": tokenData.RefreshToken,
|
|
||||||
"profile_arn": tokenData.ProfileArn,
|
|
||||||
"expires_at": tokenData.ExpiresAt,
|
|
||||||
"auth_method": tokenData.AuthMethod,
|
|
||||||
"provider": tokenData.Provider,
|
|
||||||
"client_id": tokenData.ClientID,
|
|
||||||
"client_secret": tokenData.ClientSecret,
|
|
||||||
"email": tokenData.Email,
|
|
||||||
},
|
|
||||||
Attributes: map[string]string{
|
|
||||||
"profile_arn": tokenData.ProfileArn,
|
|
||||||
"source": "aws-builder-id",
|
|
||||||
"email": tokenData.Email,
|
|
||||||
},
|
|
||||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
}
|
}
|
||||||
@@ -117,6 +134,23 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||||
|
// This shows a method selection prompt and handles both flows.
|
||||||
|
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the unified method selection flow (Builder ID or IDC)
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||||
|
tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.createAuthRecord(tokenData, "aws")
|
||||||
|
}
|
||||||
|
|
||||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
@@ -388,15 +422,23 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
|||||||
clientID, _ := auth.Metadata["client_id"].(string)
|
clientID, _ := auth.Metadata["client_id"].(string)
|
||||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
startURL, _ := auth.Metadata["start_url"].(string)
|
||||||
|
region, _ := auth.Metadata["region"].(string)
|
||||||
|
|
||||||
var tokenData *kiroauth.KiroTokenData
|
var tokenData *kiroauth.KiroTokenData
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||||
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
|
||||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||||
|
switch {
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||||
|
// IDC refresh with region-specific endpoint
|
||||||
|
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||||
|
// Builder ID refresh with default endpoint
|
||||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
} else {
|
default:
|
||||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -389,17 +390,18 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
|||||||
|
|
||||||
accountType, accountInfo := auth.AccountInfo()
|
accountType, accountInfo := auth.AccountInfo()
|
||||||
proxyInfo := auth.ProxyInfo()
|
proxyInfo := auth.ProxyInfo()
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
if accountType == "api_key" {
|
if accountType == "api_key" {
|
||||||
if proxyInfo != "" {
|
if proxyInfo != "" {
|
||||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||||
}
|
}
|
||||||
} else if accountType == "oauth" {
|
} else if accountType == "oauth" {
|
||||||
if proxyInfo != "" {
|
if proxyInfo != "" {
|
||||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -449,17 +451,18 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
|||||||
|
|
||||||
accountType, accountInfo := auth.AccountInfo()
|
accountType, accountInfo := auth.AccountInfo()
|
||||||
proxyInfo := auth.ProxyInfo()
|
proxyInfo := auth.ProxyInfo()
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
if accountType == "api_key" {
|
if accountType == "api_key" {
|
||||||
if proxyInfo != "" {
|
if proxyInfo != "" {
|
||||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||||
}
|
}
|
||||||
} else if accountType == "oauth" {
|
} else if accountType == "oauth" {
|
||||||
if proxyInfo != "" {
|
if proxyInfo != "" {
|
||||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -509,17 +512,18 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
|||||||
|
|
||||||
accountType, accountInfo := auth.AccountInfo()
|
accountType, accountInfo := auth.AccountInfo()
|
||||||
proxyInfo := auth.ProxyInfo()
|
proxyInfo := auth.ProxyInfo()
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
if accountType == "api_key" {
|
if accountType == "api_key" {
|
||||||
if proxyInfo != "" {
|
if proxyInfo != "" {
|
||||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||||
}
|
}
|
||||||
} else if accountType == "oauth" {
|
} else if accountType == "oauth" {
|
||||||
if proxyInfo != "" {
|
if proxyInfo != "" {
|
||||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1606,6 +1610,17 @@ type RequestPreparer interface {
|
|||||||
PrepareRequest(req *http.Request, auth *Auth) error
|
PrepareRequest(req *http.Request, auth *Auth) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logEntryWithRequestID returns a logrus entry with request_id field if available in context.
|
||||||
|
func logEntryWithRequestID(ctx context.Context) *log.Entry {
|
||||||
|
if ctx == nil {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
||||||
|
return log.WithField("request_id", reqID)
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
// InjectCredentials delegates per-provider HTTP request preparation when supported.
|
// InjectCredentials delegates per-provider HTTP request preparation when supported.
|
||||||
// If the registered executor for the auth provider implements RequestPreparer,
|
// If the registered executor for the auth provider implements RequestPreparer,
|
||||||
// it will be invoked to modify the request (e.g., add headers).
|
// it will be invoked to modify the request (e.g., add headers).
|
||||||
@@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider
|
|||||||
|
|
||||||
type Config = internalconfig.Config
|
type Config = internalconfig.Config
|
||||||
|
|
||||||
|
type StreamingConfig = internalconfig.StreamingConfig
|
||||||
type TLSConfig = internalconfig.TLSConfig
|
type TLSConfig = internalconfig.TLSConfig
|
||||||
type RemoteManagement = internalconfig.RemoteManagement
|
type RemoteManagement = internalconfig.RemoteManagement
|
||||||
type AmpCode = internalconfig.AmpCode
|
type AmpCode = internalconfig.AmpCode
|
||||||
|
|||||||
Reference in New Issue
Block a user