mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-29 16:54:41 +00:00
Compare commits
68 Commits
pr-59-reso
...
v6.6.61-0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d06e2dc83c | ||
|
|
336867853b | ||
|
|
6403ff4ec4 | ||
|
|
d222469b44 | ||
|
|
790a17ce98 | ||
|
|
d473c952fb | ||
|
|
7646a2b877 | ||
|
|
62090f2568 | ||
|
|
d35152bbef | ||
|
|
c281f4cbaf | ||
|
|
09455f9e85 | ||
|
|
c8e72ba0dc | ||
|
|
375ef252ab | ||
|
|
ee552f8720 | ||
|
|
2e88c4858e | ||
|
|
3f50da85c1 | ||
|
|
8be06255f7 | ||
|
|
60936b5185 | ||
|
|
72274099aa | ||
|
|
b7f7b3a1d8 | ||
|
|
dcae098e23 | ||
|
|
618606966f | ||
|
|
05f249d77f | ||
|
|
2eb05ec640 | ||
|
|
3ce0d76aa4 | ||
|
|
a00b79d9be | ||
|
|
9fe6a215e6 | ||
|
|
33e53a2a56 | ||
|
|
cd5b80785f | ||
|
|
54f71aa273 | ||
|
|
3f949b7f84 | ||
|
|
cf8b2dcc85 | ||
|
|
8e24d9dc34 | ||
|
|
443c4538bb | ||
|
|
a7fc2ee4cf | ||
|
|
8e749ac22d | ||
|
|
69e09d9bc7 | ||
|
|
ed57d82bc1 | ||
|
|
06ad527e8c | ||
|
|
7af5a90a0b | ||
|
|
7551faff79 | ||
|
|
b7409dd2de | ||
|
|
5ba325a8fc | ||
|
|
d502840f91 | ||
|
|
99238a4b59 | ||
|
|
6d43a2ff9a | ||
|
|
cdb9c2e6e8 | ||
|
|
3faa1ca9af | ||
|
|
9d975e0375 | ||
|
|
2a6d8b78d4 | ||
|
|
671558a822 | ||
|
|
6b80ec79a0 | ||
|
|
d3f4783a24 | ||
|
|
1cb6bdbc87 | ||
|
|
96ddfc1f24 | ||
|
|
c169b32570 | ||
|
|
36a512fdf2 | ||
|
|
26fbb77901 | ||
|
|
a277302262 | ||
|
|
969c1a5b72 | ||
|
|
872339bceb | ||
|
|
5dc0dbc7aa | ||
|
|
ee6fc4e8a1 | ||
|
|
8d25cf0d75 | ||
|
|
64e85e7019 | ||
|
|
349b2ba3af | ||
|
|
98db5aabd0 | ||
|
|
7fd98f3556 |
@@ -13,8 +13,6 @@ Dockerfile
|
|||||||
docs/*
|
docs/*
|
||||||
README.md
|
README.md
|
||||||
README_CN.md
|
README_CN.md
|
||||||
MANAGEMENT_API.md
|
|
||||||
MANAGEMENT_API_CN.md
|
|
||||||
LICENSE
|
LICENSE
|
||||||
|
|
||||||
# Runtime data folders (should be mounted as volumes)
|
# Runtime data folders (should be mounted as volumes)
|
||||||
@@ -32,3 +30,4 @@ bin/*
|
|||||||
.agent/*
|
.agent/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
|
_bmad-output/*
|
||||||
|
|||||||
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -7,6 +7,13 @@ assignees: ''
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
**Is it a request payload issue?**
|
||||||
|
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
|
||||||
|
[ ] No, it's another issue.
|
||||||
|
|
||||||
|
**If it's a request payload issue, you MUST know**
|
||||||
|
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
|
||||||
|
|
||||||
**Describe the bug**
|
**Describe the bug**
|
||||||
A clear and concise description of what the bug is.
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
|||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -12,11 +12,15 @@ bin/*
|
|||||||
logs/*
|
logs/*
|
||||||
conv/*
|
conv/*
|
||||||
temp/*
|
temp/*
|
||||||
|
refs/*
|
||||||
|
|
||||||
|
# Storage backends
|
||||||
pgstore/*
|
pgstore/*
|
||||||
gitstore/*
|
gitstore/*
|
||||||
objectstore/*
|
objectstore/*
|
||||||
|
|
||||||
|
# Static assets
|
||||||
static/*
|
static/*
|
||||||
refs/*
|
|
||||||
|
|
||||||
# Authentication data
|
# Authentication data
|
||||||
auths/*
|
auths/*
|
||||||
@@ -36,6 +40,7 @@ GEMINI.md
|
|||||||
.agent/*
|
.agent/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
|
_bmad-output/*
|
||||||
.mcp/cache/
|
.mcp/cache/
|
||||||
|
|
||||||
# macOS
|
# macOS
|
||||||
|
|||||||
BIN
assets/cubence.png
Normal file
BIN
assets/cubence.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
@@ -39,6 +39,9 @@ api-keys:
|
|||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
debug: false
|
debug: false
|
||||||
|
|
||||||
|
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||||
|
commercial-mode: false
|
||||||
|
|
||||||
# Open OAuth URLs in incognito/private browser mode.
|
# Open OAuth URLs in incognito/private browser mode.
|
||||||
# Useful when you want to login with a different account without logging out from your current session.
|
# Useful when you want to login with a different account without logging out from your current session.
|
||||||
# Default: false (but Kiro auth defaults to true for multi-account support)
|
# Default: false (but Kiro auth defaults to true for multi-account support)
|
||||||
|
|||||||
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
||||||
|
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
||||||
|
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
||||||
|
if h == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.cfg == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := h.logDirectory()
|
||||||
|
if strings.TrimSpace(dir) == "" {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strings.TrimSpace(c.Param("id"))
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = strings.TrimSpace(c.Query("id"))
|
||||||
|
}
|
||||||
|
if requestID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(requestID, "/\\") {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
suffix := "-" + requestID + ".log"
|
||||||
|
var matchedFile string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := entry.Name()
|
||||||
|
if strings.HasSuffix(name, suffix) {
|
||||||
|
matchedFile = name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matchedFile == "" {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dirAbs, errAbs := filepath.Abs(dir)
|
||||||
|
if errAbs != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
||||||
|
prefix := dirAbs + string(os.PathSeparator)
|
||||||
|
if !strings.HasPrefix(fullPath, prefix) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, errStat := os.Stat(fullPath)
|
||||||
|
if errStat != nil {
|
||||||
|
if os.IsNotExist(errStat) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.FileAttachment(fullPath, matchedFile)
|
||||||
|
}
|
||||||
|
|
||||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
|
|||||||
@@ -1,12 +1,25 @@
|
|||||||
package management
|
package management
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type usageExportPayload struct {
|
||||||
|
Version int `json:"version"`
|
||||||
|
ExportedAt time.Time `json:"exported_at"`
|
||||||
|
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type usageImportPayload struct {
|
||||||
|
Version int `json:"version"`
|
||||||
|
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||||
var snapshot usage.StatisticsSnapshot
|
var snapshot usage.StatisticsSnapshot
|
||||||
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
|||||||
"failed_requests": snapshot.FailureCount,
|
"failed_requests": snapshot.FailureCount,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
||||||
|
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
||||||
|
var snapshot usage.StatisticsSnapshot
|
||||||
|
if h != nil && h.usageStats != nil {
|
||||||
|
snapshot = h.usageStats.Snapshot()
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, usageExportPayload{
|
||||||
|
Version: 1,
|
||||||
|
ExportedAt: time.Now().UTC(),
|
||||||
|
Usage: snapshot,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
||||||
|
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
||||||
|
if h == nil || h.usageStats == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := c.GetRawData()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload usageImportPayload
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if payload.Version != 0 && payload.Version != 1 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result := h.usageStats.MergeSnapshot(payload.Usage)
|
||||||
|
snapshot := h.usageStats.Snapshot()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"added": result.Added,
|
||||||
|
"skipped": result.Skipped,
|
||||||
|
"total_requests": snapshot.TotalRequests,
|
||||||
|
"failed_requests": snapshot.FailureCount,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
// Resolve logs directory relative to the configuration file directory.
|
// Resolve logs directory relative to the configuration file directory.
|
||||||
var requestLogger logging.RequestLogger
|
var requestLogger logging.RequestLogger
|
||||||
var toggle func(bool)
|
var toggle func(bool)
|
||||||
if optionState.requestLoggerFactory != nil {
|
if !cfg.CommercialMode {
|
||||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
if optionState.requestLoggerFactory != nil {
|
||||||
}
|
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||||
if requestLogger != nil {
|
}
|
||||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
if requestLogger != nil {
|
||||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||||
toggle = setter.SetEnabled
|
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||||
|
toggle = setter.SetEnabled
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,6 +496,8 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
||||||
{
|
{
|
||||||
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
||||||
|
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
||||||
|
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
||||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||||
@@ -538,6 +542,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||||
|
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
||||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||||
|
|||||||
@@ -40,6 +40,10 @@ type KiroTokenData struct {
|
|||||||
ClientSecret string `json:"clientSecret,omitempty"`
|
ClientSecret string `json:"clientSecret,omitempty"`
|
||||||
// Email is the user's email address (used for file naming)
|
// Email is the user's email address (used for file naming)
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
|
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||||
|
StartURL string `json:"startUrl,omitempty"`
|
||||||
|
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||||
|
|||||||
@@ -2,16 +2,19 @@
|
|||||||
package kiro
|
package kiro
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html"
|
"html"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,19 +27,31 @@ import (
|
|||||||
const (
|
const (
|
||||||
// AWS SSO OIDC endpoints
|
// AWS SSO OIDC endpoints
|
||||||
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
|
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
|
||||||
|
|
||||||
// Kiro's start URL for Builder ID
|
// Kiro's start URL for Builder ID
|
||||||
builderIDStartURL = "https://view.awsapps.com/start"
|
builderIDStartURL = "https://view.awsapps.com/start"
|
||||||
|
|
||||||
|
// Default region for IDC
|
||||||
|
defaultIDCRegion = "us-east-1"
|
||||||
|
|
||||||
// Polling interval
|
// Polling interval
|
||||||
pollInterval = 5 * time.Second
|
pollInterval = 5 * time.Second
|
||||||
|
|
||||||
// Authorization code flow callback
|
// Authorization code flow callback
|
||||||
authCodeCallbackPath = "/oauth/callback"
|
authCodeCallbackPath = "/oauth/callback"
|
||||||
authCodeCallbackPort = 19877
|
authCodeCallbackPort = 19877
|
||||||
|
|
||||||
// User-Agent to match official Kiro IDE
|
// User-Agent to match official Kiro IDE
|
||||||
kiroUserAgent = "KiroIDE"
|
kiroUserAgent = "KiroIDE"
|
||||||
|
|
||||||
|
// IDC token refresh headers (matching Kiro IDE behavior)
|
||||||
|
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors for OIDC token polling
|
||||||
|
var (
|
||||||
|
ErrAuthorizationPending = errors.New("authorization_pending")
|
||||||
|
ErrSlowDown = errors.New("slow_down")
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
||||||
@@ -83,6 +98,440 @@ type CreateTokenResponse struct {
|
|||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getOIDCEndpoint returns the OIDC endpoint for the given region.
|
||||||
|
func getOIDCEndpoint(region string) string {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptInput prompts the user for input with an optional default value.
|
||||||
|
func promptInput(prompt, defaultValue string) string {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
if defaultValue != "" {
|
||||||
|
fmt.Printf("%s [%s]: ", prompt, defaultValue)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%s: ", prompt)
|
||||||
|
}
|
||||||
|
input, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Error reading input: %v", err)
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
input = strings.TrimSpace(input)
|
||||||
|
if input == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptSelect prompts the user to select from options using number input.
|
||||||
|
func promptSelect(prompt string, options []string) int {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
|
||||||
|
for {
|
||||||
|
fmt.Println(prompt)
|
||||||
|
for i, opt := range options {
|
||||||
|
fmt.Printf(" %d) %s\n", i+1, opt)
|
||||||
|
}
|
||||||
|
fmt.Printf("Enter selection (1-%d): ", len(options))
|
||||||
|
|
||||||
|
input, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Error reading input: %v", err)
|
||||||
|
return 0 // Default to first option on error
|
||||||
|
}
|
||||||
|
input = strings.TrimSpace(input)
|
||||||
|
|
||||||
|
// Parse the selection
|
||||||
|
var selection int
|
||||||
|
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
|
||||||
|
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return selection - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
|
||||||
|
func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"clientName": "Kiro IDE",
|
||||||
|
"clientType": "public",
|
||||||
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||||
|
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result RegisterClientResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
|
||||||
|
func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"startUrl": startURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result StartDeviceAuthResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
|
||||||
|
func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"deviceCode": deviceCode,
|
||||||
|
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for pending authorization
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
var errResp struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
|
if errResp.Error == "authorization_pending" {
|
||||||
|
return nil, ErrAuthorizationPending
|
||||||
|
}
|
||||||
|
if errResp.Error == "slow_down" {
|
||||||
|
return nil, ErrSlowDown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("create token failed: %s", string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
|
||||||
|
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
"grantType": "refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers matching kiro2api's IDC token refresh
|
||||||
|
// These headers are required for successful IDC token refresh
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
||||||
|
req.Header.Set("Accept", "*/*")
|
||||||
|
req.Header.Set("Accept-Language", "*")
|
||||||
|
req.Header.Set("sec-fetch-mode", "cors")
|
||||||
|
req.Header.Set("User-Agent", "node")
|
||||||
|
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: result.AccessToken,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "idc",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
|
||||||
|
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Step 1: Register client with the specified region
|
||||||
|
fmt.Println("\nRegistering client...")
|
||||||
|
regResp, err := c.RegisterClientWithRegion(ctx, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||||
|
|
||||||
|
// Step 2: Start device authorization with IDC start URL
|
||||||
|
fmt.Println("Starting device authorization...")
|
||||||
|
authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start device auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Show user the verification URL
|
||||||
|
fmt.Printf("\n")
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf(" Confirm the following code in the browser:\n")
|
||||||
|
fmt.Printf(" Code: %s\n", authResp.UserCode)
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
|
||||||
|
|
||||||
|
// Set incognito mode based on config
|
||||||
|
if c.cfg != nil {
|
||||||
|
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||||
|
if !c.cfg.IncognitoBrowser {
|
||||||
|
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||||
|
} else {
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
browser.SetIncognitoMode(true)
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open browser
|
||||||
|
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
|
||||||
|
log.Warnf("Could not open browser automatically: %v", err)
|
||||||
|
fmt.Println(" Please open the URL manually in your browser.")
|
||||||
|
} else {
|
||||||
|
fmt.Println(" (Browser opened automatically)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Poll for token
|
||||||
|
fmt.Println("Waiting for authorization...")
|
||||||
|
|
||||||
|
interval := pollInterval
|
||||||
|
if authResp.Interval > 0 {
|
||||||
|
interval = time.Duration(authResp.Interval) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(interval):
|
||||||
|
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrAuthorizationPending) {
|
||||||
|
fmt.Print(".")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrSlowDown) {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, fmt.Errorf("token creation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n\n✓ Authorization successful!")
|
||||||
|
|
||||||
|
// Close the browser window
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Get profile ARN from CodeWhisperer API
|
||||||
|
fmt.Println("Fetching profile information...")
|
||||||
|
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||||
|
|
||||||
|
// Fetch user email
|
||||||
|
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "idc",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: regResp.ClientID,
|
||||||
|
ClientSecret: regResp.ClientSecret,
|
||||||
|
Email: email,
|
||||||
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close browser on timeout
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
||||||
|
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Authentication (AWS) ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Prompt for login method
|
||||||
|
options := []string{
|
||||||
|
"Use with Builder ID (personal AWS account)",
|
||||||
|
"Use with IDC Account (organization SSO)",
|
||||||
|
}
|
||||||
|
selection := promptSelect("\n? Select login method:", options)
|
||||||
|
|
||||||
|
if selection == 0 {
|
||||||
|
// Builder ID flow - use existing implementation
|
||||||
|
return c.LoginWithBuilderID(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDC flow - prompt for start URL and region
|
||||||
|
fmt.Println()
|
||||||
|
startURL := promptInput("? Enter Start URL", "")
|
||||||
|
if startURL == "" {
|
||||||
|
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||||
|
}
|
||||||
|
|
||||||
|
region := promptInput("? Enter Region", defaultIDCRegion)
|
||||||
|
|
||||||
|
return c.LoginWithIDC(ctx, startURL, region)
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterClient registers a new OIDC client with AWS.
|
// RegisterClient registers a new OIDC client with AWS.
|
||||||
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||||
payload := map[string]interface{}{
|
payload := map[string]interface{}{
|
||||||
@@ -211,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
|||||||
}
|
}
|
||||||
if json.Unmarshal(respBody, &errResp) == nil {
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
if errResp.Error == "authorization_pending" {
|
if errResp.Error == "authorization_pending" {
|
||||||
return nil, fmt.Errorf("authorization_pending")
|
return nil, ErrAuthorizationPending
|
||||||
}
|
}
|
||||||
if errResp.Error == "slow_down" {
|
if errResp.Error == "slow_down" {
|
||||||
return nil, fmt.Errorf("slow_down")
|
return nil, ErrSlowDown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("create token failed: %s", string(respBody))
|
log.Debugf("create token failed: %s", string(respBody))
|
||||||
@@ -359,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
|||||||
case <-time.After(interval):
|
case <-time.After(interval):
|
||||||
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
if errors.Is(err, ErrAuthorizationPending) {
|
||||||
if strings.Contains(errStr, "authorization_pending") {
|
|
||||||
fmt.Print(".")
|
fmt.Print(".")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.Contains(errStr, "slow_down") {
|
if errors.Is(err, ErrSlowDown) {
|
||||||
interval += 5 * time.Second
|
interval += 5 * time.Second
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ type Config struct {
|
|||||||
// Debug enables or disables debug-level logging and other debug features.
|
// Debug enables or disables debug-level logging and other debug features.
|
||||||
Debug bool `yaml:"debug" json:"debug"`
|
Debug bool `yaml:"debug" json:"debug"`
|
||||||
|
|
||||||
|
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||||
|
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||||
|
|
||||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||||
|
|
||||||
@@ -876,8 +879,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||||
// key order and comments of existing keys in dst. Unknown keys from src are appended
|
// key order and comments of existing keys in dst. New keys are only added if their
|
||||||
// to dst at the end, copying their node structure from src.
|
// value is non-zero to avoid polluting the config with defaults.
|
||||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||||
if dst == nil || src == nil {
|
if dst == nil || src == nil {
|
||||||
return
|
return
|
||||||
@@ -888,20 +891,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
|||||||
copyNodeShallow(dst, src)
|
copyNodeShallow(dst, src)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Build a lookup of existing keys in dst
|
|
||||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
sk := src.Content[i]
|
sk := src.Content[i]
|
||||||
sv := src.Content[i+1]
|
sv := src.Content[i+1]
|
||||||
idx := findMapKeyIndex(dst, sk.Value)
|
idx := findMapKeyIndex(dst, sk.Value)
|
||||||
if idx >= 0 {
|
if idx >= 0 {
|
||||||
// Merge into existing value node
|
// Merge into existing value node (always update, even to zero values)
|
||||||
dv := dst.Content[idx+1]
|
dv := dst.Content[idx+1]
|
||||||
mergeNodePreserve(dv, sv)
|
mergeNodePreserve(dv, sv)
|
||||||
} else {
|
} else {
|
||||||
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
|
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||||
|
if isZeroValueNode(sv) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Append new key/value pair by deep-copying from src
|
|
||||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -984,32 +986,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
|
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||||
switch key {
|
// that should not be written as a new key to preserve config cleanliness.
|
||||||
case "generative-language-api-key",
|
// For mappings and sequences, recursively checks if all children are zero values.
|
||||||
"gemini-api-key",
|
func isZeroValueNode(node *yaml.Node) bool {
|
||||||
"vertex-api-key",
|
|
||||||
"claude-api-key",
|
|
||||||
"codex-api-key",
|
|
||||||
"openai-compatibility":
|
|
||||||
return isEmptyCollectionNode(node)
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isEmptyCollectionNode(node *yaml.Node) bool {
|
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
switch node.Kind {
|
switch node.Kind {
|
||||||
case yaml.SequenceNode:
|
|
||||||
return len(node.Content) == 0
|
|
||||||
case yaml.ScalarNode:
|
case yaml.ScalarNode:
|
||||||
return node.Tag == "!!null"
|
switch node.Tag {
|
||||||
default:
|
case "!!bool":
|
||||||
return false
|
return node.Value == "false"
|
||||||
|
case "!!int", "!!float":
|
||||||
|
return node.Value == "0" || node.Value == "0.0"
|
||||||
|
case "!!str":
|
||||||
|
return node.Value == ""
|
||||||
|
case "!!null":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case yaml.SequenceNode:
|
||||||
|
if len(node.Content) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Check if all elements are zero values
|
||||||
|
for _, child := range node.Content {
|
||||||
|
if !isZeroValueNode(child) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case yaml.MappingNode:
|
||||||
|
if len(node.Content) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Check if all values are zero values (values are at odd indices)
|
||||||
|
for i := 1; i < len(node.Content); i += 2 {
|
||||||
|
if !isZeroValueNode(node.Content[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// deepCopyNode creates a deep copy of a yaml.Node graph.
|
// deepCopyNode creates a deep copy of a yaml.Node graph.
|
||||||
|
|||||||
@@ -30,13 +30,13 @@ type SDKConfig struct {
|
|||||||
// StreamingConfig holds server streaming behavior configuration.
|
// StreamingConfig holds server streaming behavior configuration.
|
||||||
type StreamingConfig struct {
|
type StreamingConfig struct {
|
||||||
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||||
// nil means default (15 seconds). <= 0 disables keep-alives.
|
// <= 0 disables keep-alives. Default is 0.
|
||||||
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||||
|
|
||||||
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||||
// to allow auth rotation / transient recovery.
|
// to allow auth rotation / transient recovery.
|
||||||
// nil means default (2). 0 disables bootstrap retries.
|
// <= 0 disables bootstrap retries. Default is 0.
|
||||||
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccessConfig groups request authentication providers.
|
// AccessConfig groups request authentication providers.
|
||||||
|
|||||||
@@ -73,17 +73,15 @@ func GinLogrusLogger() gin.HandlerFunc {
|
|||||||
method := c.Request.Method
|
method := c.Request.Method
|
||||||
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||||
|
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = "--------"
|
||||||
|
}
|
||||||
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
|
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
|
||||||
if errorMessage != "" {
|
if errorMessage != "" {
|
||||||
logLine = logLine + " | " + errorMessage
|
logLine = logLine + " | " + errorMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
var entry *log.Entry
|
entry := log.WithField("request_id", requestID)
|
||||||
if requestID != "" {
|
|
||||||
entry = log.WithField("request_id", requestID)
|
|
||||||
} else {
|
|
||||||
entry = log.WithField("request_id", "--------")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case statusCode >= http.StatusInternalServerError:
|
case statusCode >= http.StatusInternalServerError:
|
||||||
|
|||||||
@@ -40,25 +40,22 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||||
message := strings.TrimRight(entry.Message, "\r\n")
|
message := strings.TrimRight(entry.Message, "\r\n")
|
||||||
|
|
||||||
reqID := ""
|
reqID := "--------"
|
||||||
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
|
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
|
||||||
reqID = id
|
reqID = id
|
||||||
}
|
}
|
||||||
|
|
||||||
callerFile := "unknown"
|
level := entry.Level.String()
|
||||||
callerLine := 0
|
if level == "warning" {
|
||||||
if entry.Caller != nil {
|
level = "warn"
|
||||||
callerFile = filepath.Base(entry.Caller.File)
|
|
||||||
callerLine = entry.Caller.Line
|
|
||||||
}
|
}
|
||||||
|
levelStr := fmt.Sprintf("%-5s", level)
|
||||||
levelStr := fmt.Sprintf("%-5s", entry.Level.String())
|
|
||||||
|
|
||||||
var formatted string
|
var formatted string
|
||||||
if reqID != "" {
|
if entry.Caller != nil {
|
||||||
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] | %s | %s\n", timestamp, levelStr, callerFile, callerLine, reqID, message)
|
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
||||||
} else {
|
} else {
|
||||||
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, levelStr, callerFile, callerLine, message)
|
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
|
||||||
}
|
}
|
||||||
buffer.WriteString(formatted)
|
buffer.WriteString(formatted)
|
||||||
|
|
||||||
|
|||||||
@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||||
@@ -740,6 +741,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||||
|
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||||
}
|
}
|
||||||
models := make([]*ModelInfo, 0, len(entries))
|
models := make([]*ModelInfo, 0, len(entries))
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
return resp, errValidate
|
return resp, errValidate
|
||||||
}
|
}
|
||||||
body = applyIFlowThinkingConfig(body)
|
body = applyIFlowThinkingConfig(body)
|
||||||
|
body = preserveReasoningContentInMessages(body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
@@ -159,6 +160,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, errValidate
|
return nil, errValidate
|
||||||
}
|
}
|
||||||
body = applyIFlowThinkingConfig(body)
|
body = applyIFlowThinkingConfig(body)
|
||||||
|
body = preserveReasoningContentInMessages(body)
|
||||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||||
toolsResult := gjson.GetBytes(body, "tools")
|
toolsResult := gjson.GetBytes(body, "tools")
|
||||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||||
@@ -445,20 +447,98 @@ func ensureToolsArray(body []byte) []byte {
|
|||||||
return updated
|
return updated
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
|
// preserveReasoningContentInMessages ensures reasoning_content from assistant messages in the
|
||||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
// conversation history is preserved when sending to iFlow models that support thinking.
|
||||||
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
|
// This is critical for multi-turn conversations where the model needs to see its previous
|
||||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
// reasoning to maintain coherent thought chains across tool calls and conversation turns.
|
||||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
//
|
||||||
if !effort.Exists() {
|
// For GLM-4.7 and MiniMax-M2.1, the full assistant response (including reasoning) must be
|
||||||
|
// appended back into message history before the next call.
|
||||||
|
func preserveReasoningContentInMessages(body []byte) []byte {
|
||||||
|
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||||
|
|
||||||
|
// Only apply to models that support thinking with history preservation
|
||||||
|
needsPreservation := strings.HasPrefix(model, "glm-4.7") ||
|
||||||
|
strings.HasPrefix(model, "glm-4-7") ||
|
||||||
|
strings.HasPrefix(model, "minimax-m2.1") ||
|
||||||
|
strings.HasPrefix(model, "minimax-m2-1")
|
||||||
|
|
||||||
|
if !needsPreservation {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
messages := gjson.GetBytes(body, "messages")
|
||||||
enableThinking := val != "none" && val != ""
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
// Check if any assistant message already has reasoning_content preserved
|
||||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
hasReasoningContent := false
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
role := msg.Get("role").String()
|
||||||
|
if role == "assistant" {
|
||||||
|
rc := msg.Get("reasoning_content")
|
||||||
|
if rc.Exists() && rc.String() != "" {
|
||||||
|
hasReasoningContent = true
|
||||||
|
return false // stop iteration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// If reasoning content is already present, the messages are properly formatted
|
||||||
|
// No need to modify - the client has correctly preserved reasoning in history
|
||||||
|
if hasReasoningContent {
|
||||||
|
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
|
||||||
|
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||||
|
//
|
||||||
|
// Model-specific handling:
|
||||||
|
// - GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
|
||||||
|
// - MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
|
||||||
|
// - Other iFlow models: Uses chat_template_kwargs.enable_thinking (boolean)
|
||||||
|
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||||
|
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||||
|
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||||
|
|
||||||
|
// Check if thinking should be enabled
|
||||||
|
val := ""
|
||||||
|
if effort.Exists() {
|
||||||
|
val = strings.ToLower(strings.TrimSpace(effort.String()))
|
||||||
|
}
|
||||||
|
enableThinking := effort.Exists() && val != "none" && val != ""
|
||||||
|
|
||||||
|
// Remove reasoning_effort as we'll convert to model-specific format
|
||||||
|
if effort.Exists() {
|
||||||
|
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GLM-4.7: Use extra_body with thinking config and clear_thinking: false
|
||||||
|
if strings.HasPrefix(model, "glm-4.7") || strings.HasPrefix(model, "glm-4-7") {
|
||||||
|
if enableThinking {
|
||||||
|
body, _ = sjson.SetBytes(body, "extra_body.thinking.type", "enabled")
|
||||||
|
body, _ = sjson.SetBytes(body, "extra_body.clear_thinking", false)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiniMax-M2.1: Use reasoning_split=true for interleaved thinking
|
||||||
|
if strings.HasPrefix(model, "minimax-m2.1") || strings.HasPrefix(model, "minimax-m2-1") {
|
||||||
|
if enableThinking {
|
||||||
|
body, _ = sjson.SetBytes(body, "reasoning_split", true)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other iFlow models (including GLM-4.6): Use chat_template_kwargs.enable_thinking
|
||||||
|
if effort.Exists() {
|
||||||
|
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||||
|
}
|
||||||
|
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -43,10 +45,15 @@ const (
|
|||||||
// Event Stream error type constants
|
// Event Stream error type constants
|
||||||
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
||||||
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
||||||
// kiroUserAgent matches amq2api format for User-Agent header
|
// kiroUserAgent matches amq2api format for User-Agent header (Amazon Q CLI style)
|
||||||
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
||||||
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api
|
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api (Amazon Q CLI style)
|
||||||
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
|
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
|
||||||
|
|
||||||
|
// Kiro IDE style headers (from kiro2api - for IDC auth)
|
||||||
|
kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||||
|
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||||
|
kiroIDEAgentModeSpec = "spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Real-time usage estimation configuration
|
// Real-time usage estimation configuration
|
||||||
@@ -101,11 +108,24 @@ var kiroEndpointConfigs = []kiroEndpointConfig{
|
|||||||
|
|
||||||
// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
|
// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
|
||||||
// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
|
// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
|
||||||
|
// For IDC auth method, automatically uses CodeWhisperer endpoint with CLI origin.
|
||||||
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return kiroEndpointConfigs
|
return kiroEndpointConfigs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For IDC auth, use CodeWhisperer endpoint with AI_EDITOR origin (same as Social auth)
|
||||||
|
// Based on kiro2api analysis: IDC tokens work with CodeWhisperer endpoint using Bearer auth
|
||||||
|
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
|
||||||
|
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
|
||||||
|
if auth.Metadata != nil {
|
||||||
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
if authMethod == "idc" {
|
||||||
|
log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint")
|
||||||
|
return kiroEndpointConfigs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check for preference
|
// Check for preference
|
||||||
var preference string
|
var preference string
|
||||||
if auth.Metadata != nil {
|
if auth.Metadata != nil {
|
||||||
@@ -162,6 +182,15 @@ type KiroExecutor struct {
|
|||||||
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
|
||||||
|
func isIDCAuth(auth *cliproxyauth.Auth) bool {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
return authMethod == "idc"
|
||||||
|
}
|
||||||
|
|
||||||
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
||||||
// This is critical because OpenAI and Claude formats have different tool structures:
|
// This is critical because OpenAI and Claude formats have different tool structures:
|
||||||
// - OpenAI: tools[].function.name, tools[].function.description
|
// - OpenAI: tools[].function.name, tools[].function.description
|
||||||
@@ -210,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||||
} else if refreshedAuth != nil {
|
} else if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
log.Infof("kiro: token refreshed successfully before request")
|
log.Infof("kiro: token refreshed successfully before request")
|
||||||
}
|
}
|
||||||
@@ -262,15 +295,28 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", kiroContentType)
|
httpReq.Header.Set("Content-Type", kiroContentType)
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
httpReq.Header.Set("Accept", kiroAcceptStream)
|
httpReq.Header.Set("Accept", kiroAcceptStream)
|
||||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
// Use different headers based on auth type
|
||||||
|
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||||
|
// Other auth types use Amazon Q CLI style headers
|
||||||
|
if isIDCAuth(auth) {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||||
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||||
|
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||||
|
} else {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||||
|
}
|
||||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||||
|
|
||||||
|
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
var attrs map[string]string
|
var attrs map[string]string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
@@ -358,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
|
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
// Rebuild payload with new profile ARN if changed
|
// Rebuild payload with new profile ARN if changed
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
@@ -416,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
log.Infof("kiro: token refreshed for 403, retrying request")
|
log.Infof("kiro: token refreshed for 403, retrying request")
|
||||||
@@ -518,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||||
} else if refreshedAuth != nil {
|
} else if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
log.Infof("kiro: token refreshed successfully before stream request")
|
log.Infof("kiro: token refreshed successfully before stream request")
|
||||||
}
|
}
|
||||||
@@ -568,15 +628,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", kiroContentType)
|
httpReq.Header.Set("Content-Type", kiroContentType)
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
httpReq.Header.Set("Accept", kiroAcceptStream)
|
httpReq.Header.Set("Accept", kiroAcceptStream)
|
||||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
// Use different headers based on auth type
|
||||||
|
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||||
|
// Other auth types use Amazon Q CLI style headers
|
||||||
|
if isIDCAuth(auth) {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||||
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||||
|
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||||
|
} else {
|
||||||
|
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||||
|
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||||
|
}
|
||||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||||
|
|
||||||
|
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
var attrs map[string]string
|
var attrs map[string]string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
@@ -677,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
|
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
// Rebuild payload with new profile ARN if changed
|
// Rebuild payload with new profile ARN if changed
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
@@ -735,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
if refreshedAuth != nil {
|
if refreshedAuth != nil {
|
||||||
auth = refreshedAuth
|
auth = refreshedAuth
|
||||||
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
|
}
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
log.Infof("kiro: token refreshed for 403, retrying stream request")
|
log.Infof("kiro: token refreshed for 403, retrying stream request")
|
||||||
@@ -1001,12 +1084,12 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
|||||||
// This consolidates the auth_method check that was previously done separately.
|
// This consolidates the auth_method check that was previously done separately.
|
||||||
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
||||||
// builder-id auth doesn't need profileArn
|
// builder-id and idc auth don't need profileArn
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// For non-builder-id auth (social auth), profileArn is required
|
// For non-builder-id/idc auth (social auth), profileArn is required
|
||||||
if profileArn == "" {
|
if profileArn == "" {
|
||||||
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
||||||
}
|
}
|
||||||
@@ -3010,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
var refreshToken string
|
var refreshToken string
|
||||||
var clientID, clientSecret string
|
var clientID, clientSecret string
|
||||||
var authMethod string
|
var authMethod string
|
||||||
|
var region, startURL string
|
||||||
|
|
||||||
if auth.Metadata != nil {
|
if auth.Metadata != nil {
|
||||||
if rt, ok := auth.Metadata["refresh_token"].(string); ok {
|
if rt, ok := auth.Metadata["refresh_token"].(string); ok {
|
||||||
@@ -3024,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
if am, ok := auth.Metadata["auth_method"].(string); ok {
|
if am, ok := auth.Metadata["auth_method"].(string); ok {
|
||||||
authMethod = am
|
authMethod = am
|
||||||
}
|
}
|
||||||
|
if r, ok := auth.Metadata["region"].(string); ok {
|
||||||
|
region = r
|
||||||
|
}
|
||||||
|
if su, ok := auth.Metadata["start_url"].(string); ok {
|
||||||
|
startURL = su
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
@@ -3033,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
var tokenData *kiroauth.KiroTokenData
|
var tokenData *kiroauth.KiroTokenData
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
|
||||||
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
|
||||||
|
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||||
|
switch {
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||||
|
// IDC refresh with region-specific endpoint
|
||||||
|
log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region)
|
||||||
|
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||||
|
// Builder ID refresh with default endpoint
|
||||||
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
|
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
|
||||||
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
|
|
||||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
} else {
|
default:
|
||||||
|
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
|
||||||
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
|
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
|
||||||
oauth := kiroauth.NewKiroOAuth(e.cfg)
|
oauth := kiroauth.NewKiroOAuth(e.cfg)
|
||||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||||
@@ -3094,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// persistRefreshedAuth persists a refreshed auth record to disk.
|
||||||
|
// This ensures token refreshes from inline retry are saved to the auth file.
|
||||||
|
func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return fmt.Errorf("kiro executor: cannot persist nil auth or metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the file path from auth attributes or filename
|
||||||
|
var authPath string
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
|
||||||
|
authPath = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if authPath == "" {
|
||||||
|
fileName := strings.TrimSpace(auth.FileName)
|
||||||
|
if fileName == "" {
|
||||||
|
return fmt.Errorf("kiro executor: auth has no file path or filename")
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(fileName) {
|
||||||
|
authPath = fileName
|
||||||
|
} else if e.cfg != nil && e.cfg.AuthDir != "" {
|
||||||
|
authPath = filepath.Join(e.cfg.AuthDir, fileName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("kiro executor: cannot determine auth file path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal metadata to JSON
|
||||||
|
raw, err := json.Marshal(auth.Metadata)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("kiro executor: marshal metadata failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to temp file first, then rename (atomic write)
|
||||||
|
tmp := authPath + ".tmp"
|
||||||
|
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("kiro executor: write temp auth file failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Rename(tmp, authPath); err != nil {
|
||||||
|
return fmt.Errorf("kiro executor: rename auth file failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro executor: persisted refreshed auth to %s", authPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// isTokenExpired checks if a JWT access token has expired.
|
// isTokenExpired checks if a JWT access token has expired.
|
||||||
// Returns true if the token is expired or cannot be parsed.
|
// Returns true if the token is expired or cannot be parsed.
|
||||||
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type usageReporter struct {
|
|||||||
provider string
|
provider string
|
||||||
model string
|
model string
|
||||||
authID string
|
authID string
|
||||||
authIndex uint64
|
authIndex string
|
||||||
apiKey string
|
apiKey string
|
||||||
source string
|
source string
|
||||||
requestedAt time.Time
|
requestedAt time.Time
|
||||||
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
return detail, true
|
return detail, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
|
||||||
|
detail := usage.Detail{
|
||||||
|
InputTokens: node.Get("promptTokenCount").Int(),
|
||||||
|
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||||
|
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||||
|
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||||
|
CachedTokens: node.Get("cachedContentTokenCount").Int(),
|
||||||
|
}
|
||||||
|
if detail.TotalTokens == 0 {
|
||||||
|
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||||
|
}
|
||||||
|
return detail
|
||||||
|
}
|
||||||
|
|
||||||
func parseGeminiCLIUsage(data []byte) usage.Detail {
|
func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||||
usageNode := gjson.ParseBytes(data)
|
usageNode := gjson.ParseBytes(data)
|
||||||
node := usageNode.Get("response.usageMetadata")
|
node := usageNode.Get("response.usageMetadata")
|
||||||
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
|
|||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return usage.Detail{}
|
return usage.Detail{}
|
||||||
}
|
}
|
||||||
detail := usage.Detail{
|
return parseGeminiFamilyUsageDetail(node)
|
||||||
InputTokens: node.Get("promptTokenCount").Int(),
|
|
||||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
|
||||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
|
||||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
|
||||||
}
|
|
||||||
return detail
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiUsage(data []byte) usage.Detail {
|
func parseGeminiUsage(data []byte) usage.Detail {
|
||||||
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
|
|||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return usage.Detail{}
|
return usage.Detail{}
|
||||||
}
|
}
|
||||||
detail := usage.Detail{
|
return parseGeminiFamilyUsageDetail(node)
|
||||||
InputTokens: node.Get("promptTokenCount").Int(),
|
|
||||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
|
||||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
|
||||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
|
||||||
}
|
|
||||||
return detail
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
}
|
}
|
||||||
detail := usage.Detail{
|
return parseGeminiFamilyUsageDetail(node), true
|
||||||
InputTokens: node.Get("promptTokenCount").Int(),
|
|
||||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
|
||||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
|
||||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
|
||||||
}
|
|
||||||
return detail, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
}
|
}
|
||||||
detail := usage.Detail{
|
return parseGeminiFamilyUsageDetail(node), true
|
||||||
InputTokens: node.Get("promptTokenCount").Int(),
|
|
||||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
|
||||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
|
||||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
|
||||||
}
|
|
||||||
return detail, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAntigravityUsage(data []byte) usage.Detail {
|
func parseAntigravityUsage(data []byte) usage.Detail {
|
||||||
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
|
|||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return usage.Detail{}
|
return usage.Detail{}
|
||||||
}
|
}
|
||||||
detail := usage.Detail{
|
return parseGeminiFamilyUsageDetail(node)
|
||||||
InputTokens: node.Get("promptTokenCount").Int(),
|
|
||||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
|
||||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
|
||||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
|
||||||
}
|
|
||||||
return detail
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
}
|
}
|
||||||
detail := usage.Detail{
|
return parseGeminiFamilyUsageDetail(node), true
|
||||||
InputTokens: node.Get("promptTokenCount").Int(),
|
|
||||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
|
||||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
|
||||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
|
||||||
}
|
|
||||||
return detail, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var stopChunkWithoutUsage sync.Map
|
var stopChunkWithoutUsage sync.Map
|
||||||
@@ -522,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
|
|||||||
cleaned := jsonBytes
|
cleaned := jsonBytes
|
||||||
var changed bool
|
var changed bool
|
||||||
|
|
||||||
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
|
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
|
||||||
|
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||||
|
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||||
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
|
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
|
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
|
||||||
|
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||||
|
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||||
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
|
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
// This follows the Claude Code API specification for streaming message initialization
|
// This follows the Claude Code API specification for streaming message initialization
|
||||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||||
|
|
||||||
|
// Use cpaUsageMetadata within the message_start event for Claude.
|
||||||
|
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
||||||
|
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||||
|
}
|
||||||
|
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
||||||
|
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||||
|
}
|
||||||
|
|
||||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||||
@@ -271,11 +279,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||||
params.HasUsageMetadata = true
|
params.HasUsageMetadata = true
|
||||||
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
|
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
|
||||||
|
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
|
||||||
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
|
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
|
||||||
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
|
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
|
||||||
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
|
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
|
||||||
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
|
|
||||||
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
|
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
|
||||||
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
|
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
|
||||||
if params.CandidatesTokenCount < 0 {
|
if params.CandidatesTokenCount < 0 {
|
||||||
|
|||||||
@@ -247,10 +247,30 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
} else if role == "assistant" {
|
} else if role == "assistant" {
|
||||||
node := []byte(`{"role":"model","parts":[]}`)
|
node := []byte(`{"role":"model","parts":[]}`)
|
||||||
p := 0
|
p := 0
|
||||||
if content.Type == gjson.String {
|
if content.Type == gjson.String && content.String() != "" {
|
||||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
|
||||||
p++
|
p++
|
||||||
|
} else if content.IsArray() {
|
||||||
|
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||||
|
for _, item := range content.Array() {
|
||||||
|
switch item.Get("type").String() {
|
||||||
|
case "text":
|
||||||
|
p++
|
||||||
|
case "image_url":
|
||||||
|
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||||
|
imageURL := item.Get("image_url.url").String()
|
||||||
|
if len(imageURL) > 5 { // expect data:...
|
||||||
|
pieces := strings.SplitN(imageURL[5:], ";", 2)
|
||||||
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
|
mime := pieces[0]
|
||||||
|
data := pieces[1][7:]
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
|
p++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tool calls -> single model content with functionCall parts
|
// Tool calls -> single model content with functionCall parts
|
||||||
@@ -305,6 +325,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if pp > 0 {
|
if pp > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,15 +87,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||||
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
@@ -181,12 +181,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
mimeType = "image/png"
|
mimeType = "image/png"
|
||||||
}
|
}
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
|
||||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||||
}
|
}
|
||||||
|
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||||
|
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,9 +209,12 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
|||||||
if usage := root.Get("usage"); usage.Exists() {
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
inputTokens := usage.Get("input_tokens").Int()
|
inputTokens := usage.Get("input_tokens").Int()
|
||||||
outputTokens := usage.Get("output_tokens").Int()
|
outputTokens := usage.Get("output_tokens").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens)
|
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||||
|
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||||
|
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
|
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
|
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||||
|
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||||
}
|
}
|
||||||
return []string{template}
|
return []string{template}
|
||||||
|
|
||||||
@@ -285,8 +288,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
var messageID string
|
var messageID string
|
||||||
var model string
|
var model string
|
||||||
var createdAt int64
|
var createdAt int64
|
||||||
var inputTokens, outputTokens int64
|
|
||||||
var reasoningTokens int64
|
|
||||||
var stopReason string
|
var stopReason string
|
||||||
var contentParts []string
|
var contentParts []string
|
||||||
var reasoningParts []string
|
var reasoningParts []string
|
||||||
@@ -303,9 +304,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
messageID = message.Get("id").String()
|
messageID = message.Get("id").String()
|
||||||
model = message.Get("model").String()
|
model = message.Get("model").String()
|
||||||
createdAt = time.Now().Unix()
|
createdAt = time.Now().Unix()
|
||||||
if usage := message.Get("usage"); usage.Exists() {
|
|
||||||
inputTokens = usage.Get("input_tokens").Int()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "content_block_start":
|
case "content_block_start":
|
||||||
@@ -368,11 +366,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if usage := root.Get("usage"); usage.Exists() {
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
outputTokens = usage.Get("output_tokens").Int()
|
inputTokens := usage.Get("input_tokens").Int()
|
||||||
// Estimate reasoning tokens from accumulated thinking content
|
outputTokens := usage.Get("output_tokens").Int()
|
||||||
if len(reasoningParts) > 0 {
|
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||||
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation
|
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||||
}
|
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -431,16 +432,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set usage information including prompt tokens, completion tokens, and total tokens
|
|
||||||
totalTokens := inputTokens + outputTokens
|
|
||||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
|
|
||||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
|
||||||
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
|
|
||||||
|
|
||||||
// Add reasoning tokens to usage details if any reasoning content was processed
|
|
||||||
if reasoningTokens > 0 {
|
|
||||||
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||||
parts.ForEach(func(_, part gjson.Result) bool {
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
text := part.Get("text").String()
|
textResult := part.Get("text")
|
||||||
|
text := textResult.String()
|
||||||
if builder.Len() > 0 && text != "" {
|
if builder.Len() > 0 && text != "" {
|
||||||
builder.WriteByte('\n')
|
builder.WriteByte('\n')
|
||||||
}
|
}
|
||||||
builder.WriteString(text)
|
builder.WriteString(text)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
} else if parts.Type == gjson.String {
|
||||||
|
builder.WriteString(parts.String())
|
||||||
}
|
}
|
||||||
instructionsText = builder.String()
|
instructionsText = builder.String()
|
||||||
if instructionsText != "" {
|
if instructionsText != "" {
|
||||||
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
} else if parts.Type == gjson.String {
|
||||||
|
textAggregate.WriteString(parts.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to given role if content types not decisive
|
// Fallback to given role if content types not decisive
|
||||||
|
|||||||
@@ -218,8 +218,29 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
if content.Type == gjson.String {
|
if content.Type == gjson.String {
|
||||||
// Assistant text -> single model content
|
// Assistant text -> single model content
|
||||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
|
||||||
p++
|
p++
|
||||||
|
} else if content.IsArray() {
|
||||||
|
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||||
|
for _, item := range content.Array() {
|
||||||
|
switch item.Get("type").String() {
|
||||||
|
case "text":
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
|
||||||
|
p++
|
||||||
|
case "image_url":
|
||||||
|
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||||
|
imageURL := item.Get("image_url.url").String()
|
||||||
|
if len(imageURL) > 5 { // expect data:...
|
||||||
|
pieces := strings.SplitN(imageURL[5:], ";", 2)
|
||||||
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
|
mime := pieces[0]
|
||||||
|
data := pieces[1][7:]
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
|
p++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tool calls -> single model content with functionCall parts
|
// Tool calls -> single model content with functionCall parts
|
||||||
@@ -260,6 +281,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
if pp > 0 {
|
if pp > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -170,12 +170,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
mimeType = "image/png"
|
mimeType = "image/png"
|
||||||
}
|
}
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
|
||||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||||
}
|
}
|
||||||
|
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||||
|
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -233,18 +233,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
} else if role == "assistant" {
|
} else if role == "assistant" {
|
||||||
node := []byte(`{"role":"model","parts":[]}`)
|
node := []byte(`{"role":"model","parts":[]}`)
|
||||||
p := 0
|
p := 0
|
||||||
|
|
||||||
if content.Type == gjson.String {
|
if content.Type == gjson.String {
|
||||||
// Assistant text -> single model content
|
// Assistant text -> single model content
|
||||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
|
||||||
p++
|
p++
|
||||||
} else if content.IsArray() {
|
} else if content.IsArray() {
|
||||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||||
for _, item := range content.Array() {
|
for _, item := range content.Array() {
|
||||||
switch item.Get("type").String() {
|
switch item.Get("type").String() {
|
||||||
case "text":
|
case "text":
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
|
|
||||||
p++
|
p++
|
||||||
case "image_url":
|
case "image_url":
|
||||||
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||||
@@ -261,7 +258,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tool calls -> single model content with functionCall parts
|
// Tool calls -> single model content with functionCall parts
|
||||||
@@ -302,6 +298,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
if pp > 0 {
|
if pp > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -89,15 +89,15 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||||
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
@@ -182,12 +182,14 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
mimeType = "image/png"
|
mimeType = "image/png"
|
||||||
}
|
}
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
|
||||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||||
}
|
}
|
||||||
|
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||||
|
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||||
}
|
}
|
||||||
@@ -316,12 +318,14 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
mimeType = "image/png"
|
mimeType = "image/png"
|
||||||
}
|
}
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
|
||||||
imagesResult := gjson.Get(template, "choices.0.message.images")
|
imagesResult := gjson.Get(template, "choices.0.message.images")
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
|
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
|
||||||
}
|
}
|
||||||
|
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
|
||||||
|
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
|
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
// MergeAdjacentMessages merges adjacent messages with the same role.
|
// MergeAdjacentMessages merges adjacent messages with the same role.
|
||||||
// This reduces API call complexity and improves compatibility.
|
// This reduces API call complexity and improves compatibility.
|
||||||
// Based on AIClient-2-API implementation.
|
// Based on AIClient-2-API implementation.
|
||||||
|
// NOTE: Tool messages are NOT merged because each has a unique tool_call_id that must be preserved.
|
||||||
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
||||||
if len(messages) <= 1 {
|
if len(messages) <= 1 {
|
||||||
return messages
|
return messages
|
||||||
@@ -26,6 +27,12 @@ func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
|||||||
currentRole := msg.Get("role").String()
|
currentRole := msg.Get("role").String()
|
||||||
lastRole := lastMsg.Get("role").String()
|
lastRole := lastMsg.Get("role").String()
|
||||||
|
|
||||||
|
// Don't merge tool messages - each has a unique tool_call_id
|
||||||
|
if currentRole == "tool" || lastRole == "tool" {
|
||||||
|
merged = append(merged, msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if currentRole == lastRole {
|
if currentRole == lastRole {
|
||||||
// Merge content from current message into last message
|
// Merge content from current message into last message
|
||||||
mergedContent := mergeMessageContent(lastMsg, msg)
|
mergedContent := mergeMessageContent(lastMsg, msg)
|
||||||
|
|||||||
@@ -450,24 +450,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
// Merge adjacent messages with the same role
|
// Merge adjacent messages with the same role
|
||||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||||
|
|
||||||
// Build tool_call_id to name mapping from assistant messages
|
// Track pending tool results that should be attached to the next user message
|
||||||
toolCallIDToName := make(map[string]string)
|
// This is critical for LiteLLM-translated requests where tool results appear
|
||||||
for _, msg := range messagesArray {
|
// as separate "tool" role messages between assistant and user messages
|
||||||
if msg.Get("role").String() == "assistant" {
|
var pendingToolResults []KiroToolResult
|
||||||
toolCalls := msg.Get("tool_calls")
|
|
||||||
if toolCalls.IsArray() {
|
|
||||||
for _, tc := range toolCalls.Array() {
|
|
||||||
if tc.Get("type").String() == "function" {
|
|
||||||
id := tc.Get("id").String()
|
|
||||||
name := tc.Get("function.name").String()
|
|
||||||
if id != "" && name != "" {
|
|
||||||
toolCallIDToName[id] = name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, msg := range messagesArray {
|
for i, msg := range messagesArray {
|
||||||
role := msg.Get("role").String()
|
role := msg.Get("role").String()
|
||||||
@@ -480,6 +466,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
|
|
||||||
case "user":
|
case "user":
|
||||||
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
||||||
|
// Merge any pending tool results from preceding "tool" role messages
|
||||||
|
toolResults = append(pendingToolResults, toolResults...)
|
||||||
|
pendingToolResults = nil // Reset pending tool results
|
||||||
|
|
||||||
if isLastMessage {
|
if isLastMessage {
|
||||||
currentUserMsg = &userMsg
|
currentUserMsg = &userMsg
|
||||||
currentToolResults = toolResults
|
currentToolResults = toolResults
|
||||||
@@ -505,6 +495,24 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
|
|
||||||
case "assistant":
|
case "assistant":
|
||||||
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
||||||
|
|
||||||
|
// If there are pending tool results, we need to insert a synthetic user message
|
||||||
|
// before this assistant message to maintain proper conversation structure
|
||||||
|
if len(pendingToolResults) > 0 {
|
||||||
|
syntheticUserMsg := KiroUserInputMessage{
|
||||||
|
Content: "Tool results provided.",
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
UserInputMessageContext: &KiroUserInputMessageContext{
|
||||||
|
ToolResults: pendingToolResults,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
UserInputMessage: &syntheticUserMsg,
|
||||||
|
})
|
||||||
|
pendingToolResults = nil
|
||||||
|
}
|
||||||
|
|
||||||
if isLastMessage {
|
if isLastMessage {
|
||||||
history = append(history, KiroHistoryMessage{
|
history = append(history, KiroHistoryMessage{
|
||||||
AssistantResponseMessage: &assistantMsg,
|
AssistantResponseMessage: &assistantMsg,
|
||||||
@@ -524,7 +532,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
case "tool":
|
case "tool":
|
||||||
// Tool messages in OpenAI format provide results for tool_calls
|
// Tool messages in OpenAI format provide results for tool_calls
|
||||||
// These are typically followed by user or assistant messages
|
// These are typically followed by user or assistant messages
|
||||||
// Process them and merge into the next user message's tool results
|
// Collect them as pending and attach to the next user message
|
||||||
toolCallID := msg.Get("tool_call_id").String()
|
toolCallID := msg.Get("tool_call_id").String()
|
||||||
content := msg.Get("content").String()
|
content := msg.Get("content").String()
|
||||||
|
|
||||||
@@ -534,9 +542,21 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
Content: []KiroTextContent{{Text: content}},
|
Content: []KiroTextContent{{Text: content}},
|
||||||
Status: "success",
|
Status: "success",
|
||||||
}
|
}
|
||||||
// Tool results should be included in the next user message
|
// Collect pending tool results to attach to the next user message
|
||||||
// For now, collect them and they'll be handled when we build the current message
|
pendingToolResults = append(pendingToolResults, toolResult)
|
||||||
currentToolResults = append(currentToolResults, toolResult)
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle case where tool results are at the end with no following user message
|
||||||
|
if len(pendingToolResults) > 0 {
|
||||||
|
currentToolResults = append(currentToolResults, pendingToolResults...)
|
||||||
|
// If there's no current user message, create a synthetic one for the tool results
|
||||||
|
if currentUserMsg == nil {
|
||||||
|
currentUserMsg = &KiroUserInputMessage{
|
||||||
|
Content: "Tool results provided.",
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -551,9 +571,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
|
|||||||
var toolResults []KiroToolResult
|
var toolResults []KiroToolResult
|
||||||
var images []KiroImage
|
var images []KiroImage
|
||||||
|
|
||||||
// Track seen toolCallIds to deduplicate
|
|
||||||
seenToolCallIDs := make(map[string]bool)
|
|
||||||
|
|
||||||
if content.IsArray() {
|
if content.IsArray() {
|
||||||
for _, part := range content.Array() {
|
for _, part := range content.Array() {
|
||||||
partType := part.Get("type").String()
|
partType := part.Get("type").String()
|
||||||
@@ -589,9 +606,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
|
|||||||
contentBuilder.WriteString(content.String())
|
contentBuilder.WriteString(content.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
|
|
||||||
_ = seenToolCallIDs // Used for deduplication if needed
|
|
||||||
|
|
||||||
userMsg := KiroUserInputMessage{
|
userMsg := KiroUserInputMessage{
|
||||||
Content: contentBuilder.String(),
|
Content: contentBuilder.String(),
|
||||||
ModelID: modelID,
|
ModelID: modelID,
|
||||||
|
|||||||
386
internal/translator/kiro/openai/kiro_openai_request_test.go
Normal file
386
internal/translator/kiro/openai/kiro_openai_request_test.go
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestToolResultsAttachedToCurrentMessage verifies that tool results from "tool" role messages
|
||||||
|
// are properly attached to the current user message (the last message in the conversation).
|
||||||
|
// This is critical for LiteLLM-translated requests where tool results appear as separate messages.
|
||||||
|
func TestToolResultsAttachedToCurrentMessage(t *testing.T) {
|
||||||
|
// OpenAI format request simulating LiteLLM's translation from Anthropic format
|
||||||
|
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user
|
||||||
|
// The last user message should have the tool results attached
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello, can you read a file for me?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read that file for you.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_abc123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_abc123",
|
||||||
|
"content": "File contents: Hello World!"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "What did the file say?"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The last user message becomes currentMessage
|
||||||
|
// History should have: user (first), assistant (with tool_calls)
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
if len(payload.ConversationState.History) != 2 {
|
||||||
|
t.Errorf("Expected 2 history entries (user + assistant), got %d", len(payload.ConversationState.History))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool results should be attached to currentMessage (the last user message)
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx == nil {
|
||||||
|
t.Fatal("Expected currentMessage to have UserInputMessageContext with tool results")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ctx.ToolResults) != 1 {
|
||||||
|
t.Fatalf("Expected 1 tool result in currentMessage, got %d", len(ctx.ToolResults))
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := ctx.ToolResults[0]
|
||||||
|
if tr.ToolUseID != "call_abc123" {
|
||||||
|
t.Errorf("Expected toolUseId 'call_abc123', got '%s'", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
if len(tr.Content) == 0 || tr.Content[0].Text != "File contents: Hello World!" {
|
||||||
|
t.Errorf("Tool result content mismatch, got: %+v", tr.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsInHistoryUserMessage verifies that when there are multiple user messages
|
||||||
|
// after tool results, the tool results are attached to the correct user message in history.
|
||||||
|
func TestToolResultsInHistoryUserMessage(t *testing.T) {
|
||||||
|
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user -> assistant -> user
|
||||||
|
// The first user after tool should have tool results in history
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read the file.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": "File result"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Thanks for the file"},
|
||||||
|
{"role": "assistant", "content": "You're welcome"},
|
||||||
|
{"role": "user", "content": "Bye"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// History should have: user, assistant, user (with tool results), assistant
|
||||||
|
// CurrentMessage should be: last user "Bye"
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
|
||||||
|
// Find the user message in history with tool results
|
||||||
|
foundToolResults := false
|
||||||
|
for i, h := range payload.ConversationState.History {
|
||||||
|
if h.UserInputMessage != nil {
|
||||||
|
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
|
||||||
|
if h.UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
if len(h.UserInputMessage.UserInputMessageContext.ToolResults) > 0 {
|
||||||
|
foundToolResults = true
|
||||||
|
t.Logf(" Found %d tool results", len(h.UserInputMessage.UserInputMessageContext.ToolResults))
|
||||||
|
tr := h.UserInputMessage.UserInputMessageContext.ToolResults[0]
|
||||||
|
if tr.ToolUseID != "call_1" {
|
||||||
|
t.Errorf("Expected toolUseId 'call_1', got '%s'", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.AssistantResponseMessage != nil {
|
||||||
|
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundToolResults {
|
||||||
|
t.Error("Tool results were not attached to any user message in history")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsWithMultipleToolCalls verifies handling of multiple tool calls
|
||||||
|
func TestToolResultsWithMultipleToolCalls(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Read two files for me"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read both files.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/file1.txt\"}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_2",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/file2.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": "Content of file 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_2",
|
||||||
|
"content": "Content of file 2"
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "What do they say?"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
t.Logf("CurrentMessage content: %q", payload.ConversationState.CurrentMessage.UserInputMessage.Content)
|
||||||
|
|
||||||
|
// Check if there are any tool results anywhere
|
||||||
|
var totalToolResults int
|
||||||
|
for i, h := range payload.ConversationState.History {
|
||||||
|
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
|
||||||
|
t.Logf("History[%d] user message has %d tool results", i, count)
|
||||||
|
totalToolResults += count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx != nil {
|
||||||
|
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
|
||||||
|
totalToolResults += len(ctx.ToolResults)
|
||||||
|
} else {
|
||||||
|
t.Logf("CurrentMessage has no UserInputMessageContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalToolResults != 2 {
|
||||||
|
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsAtEndOfConversation verifies tool results are handled when
|
||||||
|
// the conversation ends with tool results (no following user message)
|
||||||
|
func TestToolResultsAtEndOfConversation(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Read a file"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Reading the file.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_end",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_end",
|
||||||
|
"content": "File contents here"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the last message is a tool result, a synthetic user message is created
|
||||||
|
// and tool results should be attached to it
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx == nil || len(ctx.ToolResults) == 0 {
|
||||||
|
t.Error("Expected tool results to be attached to current message when conversation ends with tool result")
|
||||||
|
} else {
|
||||||
|
if ctx.ToolResults[0].ToolUseID != "call_end" {
|
||||||
|
t.Errorf("Expected toolUseId 'call_end', got '%s'", ctx.ToolResults[0].ToolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolResultsFollowedByAssistant verifies handling when tool results are followed
|
||||||
|
// by an assistant message (no intermediate user message).
|
||||||
|
// This is the pattern from LiteLLM translation of Anthropic format where:
|
||||||
|
// user message has ONLY tool_result blocks -> LiteLLM creates tool messages
|
||||||
|
// then the next message is assistant
|
||||||
|
func TestToolResultsFollowedByAssistant(t *testing.T) {
|
||||||
|
// Sequence: user -> assistant (with tool_calls) -> tool -> tool -> assistant -> user
|
||||||
|
// This simulates LiteLLM's translation of:
|
||||||
|
// user: "Read files"
|
||||||
|
// assistant: [tool_use, tool_use]
|
||||||
|
// user: [tool_result, tool_result] <- becomes multiple "tool" role messages
|
||||||
|
// assistant: "I've read them"
|
||||||
|
// user: "What did they say?"
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Read two files for me"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'll read both files.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/a.txt\"}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_2",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "Read",
|
||||||
|
"arguments": "{\"file_path\": \"/tmp/b.txt\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": "Contents of file A"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_2",
|
||||||
|
"content": "Contents of file B"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I've read both files."
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "What did they say?"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("History count: %d", len(payload.ConversationState.History))
|
||||||
|
|
||||||
|
// Tool results should be attached to a synthetic user message or the history should be valid
|
||||||
|
var totalToolResults int
|
||||||
|
for i, h := range payload.ConversationState.History {
|
||||||
|
if h.UserInputMessage != nil {
|
||||||
|
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
|
||||||
|
if h.UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
|
||||||
|
t.Logf(" Has %d tool results", count)
|
||||||
|
totalToolResults += count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.AssistantResponseMessage != nil {
|
||||||
|
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx != nil {
|
||||||
|
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
|
||||||
|
totalToolResults += len(ctx.ToolResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalToolResults != 2 {
|
||||||
|
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAssistantEndsConversation verifies handling when assistant is the last message
|
||||||
|
func TestAssistantEndsConversation(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "kiro-claude-opus-4-5-agentic",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hi there!"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
|
||||||
|
|
||||||
|
var payload KiroPayload
|
||||||
|
if err := json.Unmarshal(result, &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When assistant is last, a "Continue" user message should be created
|
||||||
|
if payload.ConversationState.CurrentMessage.UserInputMessage.Content == "" {
|
||||||
|
t.Error("Expected a 'Continue' message to be created when assistant is last")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ package usage
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -90,7 +91,7 @@ type modelStats struct {
|
|||||||
type RequestDetail struct {
|
type RequestDetail struct {
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
Source string `json:"source"`
|
Source string `json:"source"`
|
||||||
AuthIndex uint64 `json:"auth_index"`
|
AuthIndex string `json:"auth_index"`
|
||||||
Tokens TokenStats `json:"tokens"`
|
Tokens TokenStats `json:"tokens"`
|
||||||
Failed bool `json:"failed"`
|
Failed bool `json:"failed"`
|
||||||
}
|
}
|
||||||
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MergeResult struct {
|
||||||
|
Added int64 `json:"added"`
|
||||||
|
Skipped int64 `json:"skipped"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeSnapshot merges an exported statistics snapshot into the current store.
|
||||||
|
// Existing data is preserved and duplicate request details are skipped.
|
||||||
|
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
|
||||||
|
result := MergeResult{}
|
||||||
|
if s == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
for apiName, stats := range s.apis {
|
||||||
|
if stats == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for modelName, modelStatsValue := range stats.Models {
|
||||||
|
if modelStatsValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, detail := range modelStatsValue.Details {
|
||||||
|
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for apiName, apiSnapshot := range snapshot.APIs {
|
||||||
|
apiName = strings.TrimSpace(apiName)
|
||||||
|
if apiName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats, ok := s.apis[apiName]
|
||||||
|
if !ok || stats == nil {
|
||||||
|
stats = &apiStats{Models: make(map[string]*modelStats)}
|
||||||
|
s.apis[apiName] = stats
|
||||||
|
} else if stats.Models == nil {
|
||||||
|
stats.Models = make(map[string]*modelStats)
|
||||||
|
}
|
||||||
|
for modelName, modelSnapshot := range apiSnapshot.Models {
|
||||||
|
modelName = strings.TrimSpace(modelName)
|
||||||
|
if modelName == "" {
|
||||||
|
modelName = "unknown"
|
||||||
|
}
|
||||||
|
for _, detail := range modelSnapshot.Details {
|
||||||
|
detail.Tokens = normaliseTokenStats(detail.Tokens)
|
||||||
|
if detail.Timestamp.IsZero() {
|
||||||
|
detail.Timestamp = time.Now()
|
||||||
|
}
|
||||||
|
key := dedupKey(apiName, modelName, detail)
|
||||||
|
if _, exists := seen[key]; exists {
|
||||||
|
result.Skipped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
s.recordImported(apiName, modelName, stats, detail)
|
||||||
|
result.Added++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
|
||||||
|
totalTokens := detail.Tokens.TotalTokens
|
||||||
|
if totalTokens < 0 {
|
||||||
|
totalTokens = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
s.totalRequests++
|
||||||
|
if detail.Failed {
|
||||||
|
s.failureCount++
|
||||||
|
} else {
|
||||||
|
s.successCount++
|
||||||
|
}
|
||||||
|
s.totalTokens += totalTokens
|
||||||
|
|
||||||
|
s.updateAPIStats(stats, modelName, detail)
|
||||||
|
|
||||||
|
dayKey := detail.Timestamp.Format("2006-01-02")
|
||||||
|
hourKey := detail.Timestamp.Hour()
|
||||||
|
|
||||||
|
s.requestsByDay[dayKey]++
|
||||||
|
s.requestsByHour[hourKey]++
|
||||||
|
s.tokensByDay[dayKey] += totalTokens
|
||||||
|
s.tokensByHour[hourKey] += totalTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupKey(apiName, modelName string, detail RequestDetail) string {
|
||||||
|
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
|
||||||
|
tokens := normaliseTokenStats(detail.Tokens)
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
|
||||||
|
apiName,
|
||||||
|
modelName,
|
||||||
|
timestamp,
|
||||||
|
detail.Source,
|
||||||
|
detail.AuthIndex,
|
||||||
|
detail.Failed,
|
||||||
|
tokens.InputTokens,
|
||||||
|
tokens.OutputTokens,
|
||||||
|
tokens.ReasoningTokens,
|
||||||
|
tokens.CachedTokens,
|
||||||
|
tokens.TotalTokens,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
|
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
|
||||||
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
|
|||||||
return tokens
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normaliseTokenStats(tokens TokenStats) TokenStats {
|
||||||
|
if tokens.TotalTokens == 0 {
|
||||||
|
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
|
||||||
|
}
|
||||||
|
if tokens.TotalTokens == 0 {
|
||||||
|
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
func formatHour(hour int) string {
|
func formatHour(hour int) string {
|
||||||
if hour < 0 {
|
if hour < 0 {
|
||||||
hour = 0
|
hour = 0
|
||||||
|
|||||||
@@ -104,8 +104,8 @@ func BuildErrorResponseBody(status int, errText string) []byte {
|
|||||||
// Returning 0 disables keep-alives (default when unset).
|
// Returning 0 disables keep-alives (default when unset).
|
||||||
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||||
seconds := defaultStreamingKeepAliveSeconds
|
seconds := defaultStreamingKeepAliveSeconds
|
||||||
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
|
if cfg != nil {
|
||||||
seconds = *cfg.Streaming.KeepAliveSeconds
|
seconds = cfg.Streaming.KeepAliveSeconds
|
||||||
}
|
}
|
||||||
if seconds <= 0 {
|
if seconds <= 0 {
|
||||||
return 0
|
return 0
|
||||||
@@ -116,8 +116,8 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
|||||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||||
retries := defaultStreamingBootstrapRetries
|
retries := defaultStreamingBootstrapRetries
|
||||||
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
|
if cfg != nil {
|
||||||
retries = *cfg.Streaming.BootstrapRetries
|
retries = cfg.Streaming.BootstrapRetries
|
||||||
}
|
}
|
||||||
if retries < 0 {
|
if retries < 0 {
|
||||||
retries = 0
|
retries = 0
|
||||||
|
|||||||
@@ -94,12 +94,11 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
bootstrapRetries := 1
|
|
||||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
Streaming: sdkconfig.StreamingConfig{
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
BootstrapRetries: &bootstrapRetries,
|
BootstrapRetries: 1,
|
||||||
},
|
},
|
||||||
}, manager, nil)
|
}, manager)
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
|
|||||||
116
sdk/auth/kiro.go
116
sdk/auth/kiro.go
@@ -53,20 +53,8 @@ func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
|||||||
return &d
|
return &d
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login performs OAuth login for Kiro with AWS Builder ID.
|
// createAuthRecord creates an auth record from token data.
|
||||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||||
if cfg == nil {
|
|
||||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
|
||||||
|
|
||||||
// Use AWS Builder ID device code flow
|
|
||||||
tokenData, err := oauth.LoginWithBuilderID(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("login failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse expires_at
|
// Parse expires_at
|
||||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -76,34 +64,63 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
// Extract identifier for file naming
|
// Extract identifier for file naming
|
||||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
// Determine label based on auth method
|
||||||
|
label := fmt.Sprintf("kiro-%s", source)
|
||||||
|
if tokenData.AuthMethod == "idc" {
|
||||||
|
label = "kiro-idc"
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||||
|
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"client_id": tokenData.ClientID,
|
||||||
|
"client_secret": tokenData.ClientSecret,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IDC-specific fields if present
|
||||||
|
if tokenData.StartURL != "" {
|
||||||
|
metadata["start_url"] = tokenData.StartURL
|
||||||
|
}
|
||||||
|
if tokenData.Region != "" {
|
||||||
|
metadata["region"] = tokenData.Region
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes := map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": source,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IDC-specific attributes if present
|
||||||
|
if tokenData.AuthMethod == "idc" {
|
||||||
|
attributes["source"] = "aws-idc"
|
||||||
|
if tokenData.StartURL != "" {
|
||||||
|
attributes["start_url"] = tokenData.StartURL
|
||||||
|
}
|
||||||
|
if tokenData.Region != "" {
|
||||||
|
attributes["region"] = tokenData.Region
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: "kiro",
|
Provider: "kiro",
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Label: "kiro-aws",
|
Label: label,
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
Metadata: map[string]any{
|
Metadata: metadata,
|
||||||
"type": "kiro",
|
Attributes: attributes,
|
||||||
"access_token": tokenData.AccessToken,
|
|
||||||
"refresh_token": tokenData.RefreshToken,
|
|
||||||
"profile_arn": tokenData.ProfileArn,
|
|
||||||
"expires_at": tokenData.ExpiresAt,
|
|
||||||
"auth_method": tokenData.AuthMethod,
|
|
||||||
"provider": tokenData.Provider,
|
|
||||||
"client_id": tokenData.ClientID,
|
|
||||||
"client_secret": tokenData.ClientSecret,
|
|
||||||
"email": tokenData.Email,
|
|
||||||
},
|
|
||||||
Attributes: map[string]string{
|
|
||||||
"profile_arn": tokenData.ProfileArn,
|
|
||||||
"source": "aws-builder-id",
|
|
||||||
"email": tokenData.Email,
|
|
||||||
},
|
|
||||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
}
|
}
|
||||||
@@ -117,6 +134,23 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||||
|
// This shows a method selection prompt and handles both flows.
|
||||||
|
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the unified method selection flow (Builder ID or IDC)
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||||
|
tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.createAuthRecord(tokenData, "aws")
|
||||||
|
}
|
||||||
|
|
||||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
@@ -388,15 +422,23 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
|||||||
clientID, _ := auth.Metadata["client_id"].(string)
|
clientID, _ := auth.Metadata["client_id"].(string)
|
||||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
startURL, _ := auth.Metadata["start_url"].(string)
|
||||||
|
region, _ := auth.Metadata["region"].(string)
|
||||||
|
|
||||||
var tokenData *kiroauth.KiroTokenData
|
var tokenData *kiroauth.KiroTokenData
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||||
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
|
||||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||||
|
switch {
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||||
|
// IDC refresh with region-specific endpoint
|
||||||
|
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||||
|
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||||
|
// Builder ID refresh with default endpoint
|
||||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
} else {
|
default:
|
||||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||||
|
|||||||
@@ -203,10 +203,10 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
if auth == nil {
|
if auth == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
auth.EnsureIndex()
|
|
||||||
if auth.ID == "" {
|
if auth.ID == "" {
|
||||||
auth.ID = uuid.NewString()
|
auth.ID = uuid.NewString()
|
||||||
}
|
}
|
||||||
|
auth.EnsureIndex()
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.auths[auth.ID] = auth.Clone()
|
m.auths[auth.ID] = auth.Clone()
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
@@ -221,7 +221,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 {
|
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" {
|
||||||
auth.Index = existing.Index
|
auth.Index = existing.Index
|
||||||
auth.indexAssigned = existing.indexAssigned
|
auth.indexAssigned = existing.indexAssigned
|
||||||
}
|
}
|
||||||
@@ -263,7 +263,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
rotated := m.rotateProviders(req.Model, normalized)
|
||||||
defer m.advanceProviderCursor(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -302,7 +301,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
rotated := m.rotateProviders(req.Model, normalized)
|
||||||
defer m.advanceProviderCursor(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -341,7 +339,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
rotated := m.rotateProviders(req.Model, normalized)
|
||||||
defer m.advanceProviderCursor(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -640,13 +637,20 @@ func (m *Manager) normalizeProviders(providers []string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rotateProviders returns a rotated view of the providers list starting from the
|
||||||
|
// current offset for the model, and atomically increments the offset for the next call.
|
||||||
|
// This ensures concurrent requests get different starting providers.
|
||||||
func (m *Manager) rotateProviders(model string, providers []string) []string {
|
func (m *Manager) rotateProviders(model string, providers []string) []string {
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.mu.RLock()
|
|
||||||
|
// Atomic read-and-increment: get current offset and advance cursor in one lock
|
||||||
|
m.mu.Lock()
|
||||||
offset := m.providerOffsets[model]
|
offset := m.providerOffsets[model]
|
||||||
m.mu.RUnlock()
|
m.providerOffsets[model] = (offset + 1) % len(providers)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
if len(providers) > 0 {
|
if len(providers) > 0 {
|
||||||
offset %= len(providers)
|
offset %= len(providers)
|
||||||
}
|
}
|
||||||
@@ -662,19 +666,6 @@ func (m *Manager) rotateProviders(model string, providers []string) []string {
|
|||||||
return rotated
|
return rotated
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) advanceProviderCursor(model string, providers []string) {
|
|
||||||
if len(providers) == 0 {
|
|
||||||
m.mu.Lock()
|
|
||||||
delete(m.providerOffsets, model)
|
|
||||||
m.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.mu.Lock()
|
|
||||||
current := m.providerOffsets[model]
|
|
||||||
m.providerOffsets[model] = (current + 1) % len(providers)
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return 0, 0
|
return 0, 0
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||||
@@ -15,8 +16,8 @@ import (
|
|||||||
type Auth struct {
|
type Auth struct {
|
||||||
// ID uniquely identifies the auth record across restarts.
|
// ID uniquely identifies the auth record across restarts.
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
// Index is a monotonically increasing runtime identifier used for diagnostics.
|
// Index is a stable runtime identifier derived from auth metadata (not persisted).
|
||||||
Index uint64 `json:"-"`
|
Index string `json:"-"`
|
||||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
||||||
@@ -94,12 +95,6 @@ type ModelState struct {
|
|||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var authIndexCounter atomic.Uint64
|
|
||||||
|
|
||||||
func nextAuthIndex() uint64 {
|
|
||||||
return authIndexCounter.Add(1) - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
|
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
|
||||||
func (a *Auth) Clone() *Auth {
|
func (a *Auth) Clone() *Auth {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
@@ -128,15 +123,41 @@ func (a *Auth) Clone() *Auth {
|
|||||||
return ©Auth
|
return ©Auth
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureIndex returns the global index, assigning one if it was not set yet.
|
func stableAuthIndex(seed string) string {
|
||||||
func (a *Auth) EnsureIndex() uint64 {
|
seed = strings.TrimSpace(seed)
|
||||||
if a == nil {
|
if seed == "" {
|
||||||
return 0
|
return ""
|
||||||
}
|
}
|
||||||
if a.indexAssigned {
|
sum := sha256.Sum256([]byte(seed))
|
||||||
|
return hex.EncodeToString(sum[:8])
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureIndex returns a stable index derived from the auth file name or API key.
|
||||||
|
func (a *Auth) EnsureIndex() string {
|
||||||
|
if a == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if a.indexAssigned && a.Index != "" {
|
||||||
return a.Index
|
return a.Index
|
||||||
}
|
}
|
||||||
idx := nextAuthIndex()
|
|
||||||
|
seed := strings.TrimSpace(a.FileName)
|
||||||
|
if seed != "" {
|
||||||
|
seed = "file:" + seed
|
||||||
|
} else if a.Attributes != nil {
|
||||||
|
if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" {
|
||||||
|
seed = "api_key:" + apiKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if seed == "" {
|
||||||
|
if id := strings.TrimSpace(a.ID); id != "" {
|
||||||
|
seed = "id:" + id
|
||||||
|
} else {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := stableAuthIndex(seed)
|
||||||
a.Index = idx
|
a.Index = idx
|
||||||
a.indexAssigned = true
|
a.indexAssigned = true
|
||||||
return idx
|
return idx
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ type Record struct {
|
|||||||
Model string
|
Model string
|
||||||
APIKey string
|
APIKey string
|
||||||
AuthID string
|
AuthID string
|
||||||
AuthIndex uint64
|
AuthIndex string
|
||||||
Source string
|
Source string
|
||||||
RequestedAt time.Time
|
RequestedAt time.Time
|
||||||
Failed bool
|
Failed bool
|
||||||
|
|||||||
Reference in New Issue
Block a user