mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-21 07:02:05 +00:00
Merge branch 'main' into plus
This commit is contained in:
@@ -190,10 +190,11 @@ type BaseAPIHandler struct {
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
h := &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// UpdateClients updates the handlers' client list and configuration.
|
||||
|
||||
37
sdk/api/handlers/openai/endpoint_compat.go
Normal file
37
sdk/api/handlers/openai/endpoint_compat.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package openai
|
||||
|
||||
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
|
||||
const (
|
||||
openAIChatEndpoint = "/chat/completions"
|
||||
openAIResponsesEndpoint = "/responses"
|
||||
)
|
||||
|
||||
func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool) {
|
||||
if modelName == "" {
|
||||
return "", false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(modelName, "")
|
||||
if info == nil || len(info.SupportedEndpoints) == 0 {
|
||||
return "", false
|
||||
}
|
||||
if endpointListContains(info.SupportedEndpoints, requestedEndpoint) {
|
||||
return "", false
|
||||
}
|
||||
if requestedEndpoint == openAIChatEndpoint && endpointListContains(info.SupportedEndpoints, openAIResponsesEndpoint) {
|
||||
return openAIResponsesEndpoint, true
|
||||
}
|
||||
if requestedEndpoint == openAIResponsesEndpoint && endpointListContains(info.SupportedEndpoints, openAIChatEndpoint) {
|
||||
return openAIChatEndpoint, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func endpointListContains(items []string, value string) bool {
|
||||
for _, item := range items {
|
||||
if item == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
codexconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/openai/chat-completions"
|
||||
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -112,6 +113,23 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
stream := streamResult.Type == gjson.True
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIChatEndpoint); ok && overrideEndpoint == openAIResponsesEndpoint {
|
||||
originalChat := rawJSON
|
||||
if shouldTreatAsResponsesFormat(rawJSON) {
|
||||
// Already responses-style payload; no conversion needed.
|
||||
} else {
|
||||
rawJSON = codexconverter.ConvertOpenAIRequestToCodex(modelName, rawJSON, stream)
|
||||
}
|
||||
stream = gjson.GetBytes(rawJSON, "stream").Bool()
|
||||
if stream {
|
||||
h.handleStreamingResponseViaResponses(c, rawJSON, originalChat)
|
||||
} else {
|
||||
h.handleNonStreamingResponseViaResponses(c, rawJSON, originalChat)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Some clients send OpenAI Responses-format payloads to /v1/chat/completions.
|
||||
// Convert them to Chat Completions so downstream translators preserve tool metadata.
|
||||
if shouldTreatAsResponsesFormat(rawJSON) {
|
||||
@@ -245,6 +263,76 @@ func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func convertResponsesObjectToChatCompletion(ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, responsesPayload []byte) []byte {
|
||||
if len(responsesPayload) == 0 {
|
||||
return nil
|
||||
}
|
||||
wrapped := wrapResponsesPayloadAsCompleted(responsesPayload)
|
||||
if len(wrapped) == 0 {
|
||||
return nil
|
||||
}
|
||||
var param any
|
||||
converted := codexconverter.ConvertCodexResponseToOpenAINonStream(ctx, modelName, originalChatJSON, responsesRequestJSON, wrapped, ¶m)
|
||||
if converted == "" {
|
||||
return nil
|
||||
}
|
||||
return []byte(converted)
|
||||
}
|
||||
|
||||
func wrapResponsesPayloadAsCompleted(payload []byte) []byte {
|
||||
if gjson.GetBytes(payload, "type").Exists() {
|
||||
return payload
|
||||
}
|
||||
if gjson.GetBytes(payload, "object").String() != "response" {
|
||||
return payload
|
||||
}
|
||||
wrapped := `{"type":"response.completed","response":{}}`
|
||||
wrapped, _ = sjson.SetRaw(wrapped, "response", string(payload))
|
||||
return []byte(wrapped)
|
||||
}
|
||||
|
||||
func writeConvertedResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, chunk []byte, param *any) {
|
||||
outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) forwardResponsesAsChatStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON []byte, param *any) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
|
||||
}
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
|
||||
// This ensures the completions endpoint returns data in the expected format.
|
||||
//
|
||||
@@ -435,6 +523,30 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
||||
if converted == nil {
|
||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to convert response to chat completion format"),
|
||||
})
|
||||
cliCancel(fmt.Errorf("response conversion failed"))
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(converted)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
@@ -509,6 +621,67 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||
var param any
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek for first usable chunk
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
setSSEHeaders()
|
||||
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
||||
flusher.Flush()
|
||||
|
||||
h.forwardResponsesAsChatStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalChatJSON, rawJSON, ¶m)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||
// It converts completions request to chat completions format, sends to backend,
|
||||
// then converts the response back to completions format before sending to client.
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -83,7 +84,21 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
stream := streamResult.Type == gjson.True
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIResponsesEndpoint); ok && overrideEndpoint == openAIChatEndpoint {
|
||||
chatJSON := responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName, rawJSON, stream)
|
||||
stream = gjson.GetBytes(chatJSON, "stream").Bool()
|
||||
if stream {
|
||||
h.handleStreamingResponseViaChat(c, rawJSON, chatJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponseViaChat(c, rawJSON, chatJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
h.handleStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponse(c, rawJSON)
|
||||
@@ -116,6 +131,31 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
var param any
|
||||
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
||||
if converted == "" {
|
||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to convert chat completion response to responses format"),
|
||||
})
|
||||
cliCancel(fmt.Errorf("response conversion failed"))
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte(converted))
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
@@ -196,6 +236,116 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||
var param any
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
setSSEHeaders()
|
||||
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
||||
flusher.Flush()
|
||||
|
||||
h.forwardChatAsResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalResponsesJSON, ¶m)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeChatAsResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalResponsesJSON, chunk []byte, param *any) {
|
||||
outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix([]byte(out), []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte(out))
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalResponsesJSON []byte, param *any) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix([]byte(out), []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte(out))
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
|
||||
@@ -216,6 +216,15 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
return nil, fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
id := s.idFor(path, baseDir)
|
||||
|
||||
// Calculate NextRefreshAfter from expires_at (20 minutes before expiry)
|
||||
var nextRefreshAfter time.Time
|
||||
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if expiresAt, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
nextRefreshAfter = expiresAt.Add(-20 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
@@ -227,7 +236,7 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
CreatedAt: info.ModTime(),
|
||||
UpdatedAt: info.ModTime(),
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
NextRefreshAfter: nextRefreshAfter,
|
||||
}
|
||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||
auth.Attributes["email"] = email
|
||||
|
||||
129
sdk/auth/github_copilot.go
Normal file
129
sdk/auth/github_copilot.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot.
|
||||
type GitHubCopilotAuthenticator struct{}
|
||||
|
||||
// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator.
|
||||
func NewGitHubCopilotAuthenticator() Authenticator {
|
||||
return &GitHubCopilotAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for github-copilot.
|
||||
func (GitHubCopilotAuthenticator) Provider() string {
|
||||
return "github-copilot"
|
||||
}
|
||||
|
||||
// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense.
|
||||
// The token remains valid until the user revokes it or the Copilot subscription expires.
|
||||
func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login initiates the GitHub device flow authentication for Copilot access.
|
||||
func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
authSvc := copilot.NewCopilotAuth(cfg)
|
||||
|
||||
// Start the device flow
|
||||
fmt.Println("Starting GitHub Copilot authentication...")
|
||||
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err)
|
||||
}
|
||||
|
||||
// Display the user code and verification URL
|
||||
fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI)
|
||||
fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode)
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for GitHub authorization...")
|
||||
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||
|
||||
// Wait for user authorization
|
||||
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||
if err != nil {
|
||||
errMsg := copilot.GetUserFriendlyMessage(err)
|
||||
return nil, fmt.Errorf("github-copilot: %s", errMsg)
|
||||
}
|
||||
|
||||
// Verify the token can get a Copilot API token
|
||||
fmt.Println("Verifying Copilot access...")
|
||||
apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err)
|
||||
}
|
||||
|
||||
// Create the token storage
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
// Build metadata with token information for the executor
|
||||
metadata := map[string]any{
|
||||
"type": "github-copilot",
|
||||
"username": authBundle.Username,
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if apiToken.ExpiresAt > 0 {
|
||||
metadata["api_token_expires_at"] = apiToken.ExpiresAt
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
||||
|
||||
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: authBundle.Username,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshGitHubCopilotToken validates and returns the current token status.
|
||||
// GitHub OAuth tokens don't need traditional refresh - we just validate they still work.
|
||||
func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error {
|
||||
if storage == nil || storage.AccessToken == "" {
|
||||
return fmt.Errorf("no token available")
|
||||
}
|
||||
|
||||
authSvc := copilot.NewCopilotAuth(cfg)
|
||||
|
||||
// Validate the token can still get a Copilot API token
|
||||
_, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
374
sdk/auth/kiro.go
Normal file
374
sdk/auth/kiro.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||
// Returns account name if provided, otherwise profile ARN ID, then client ID.
|
||||
// All extracted values are sanitized to prevent path injection attacks.
|
||||
func extractKiroIdentifier(accountName, profileArn, clientID string) string {
|
||||
// Priority 1: Use account name if provided
|
||||
if accountName != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||
}
|
||||
|
||||
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
|
||||
if profileArn != "" {
|
||||
parts := strings.Split(profileArn, "/")
|
||||
if len(parts) >= 2 {
|
||||
// Sanitize the ARN component to prevent path traversal
|
||||
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Use client ID (for IDC auth without email/profileArn)
|
||||
if clientID != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(clientID)
|
||||
}
|
||||
|
||||
// Fallback: timestamp
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
|
||||
type KiroAuthenticator struct{}
|
||||
|
||||
// NewKiroAuthenticator constructs a Kiro authenticator.
|
||||
func NewKiroAuthenticator() *KiroAuthenticator {
|
||||
return &KiroAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for the authenticator.
|
||||
func (a *KiroAuthenticator) Provider() string {
|
||||
return "kiro"
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 20 minutes for proactive refresh before token expiry.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 20 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
// createAuthRecord creates an auth record from token data.
|
||||
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Determine label and identifier based on auth method
|
||||
var label, idPart string
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
label = "kiro-idc"
|
||||
// For IDC auth, always use clientID as identifier
|
||||
if tokenData.ClientID != "" {
|
||||
idPart = kiroauth.SanitizeEmailForFilename(tokenData.ClientID)
|
||||
} else {
|
||||
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
} else {
|
||||
label = fmt.Sprintf("kiro-%s", source)
|
||||
idPart = extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific fields if present
|
||||
if tokenData.StartURL != "" {
|
||||
metadata["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
metadata["region"] = tokenData.Region
|
||||
}
|
||||
|
||||
attributes := map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": source,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific attributes if present
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
attributes["source"] = "aws-idc"
|
||||
if tokenData.StartURL != "" {
|
||||
attributes["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
attributes["region"] = tokenData.Region
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: metadata,
|
||||
Attributes: attributes,
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||
// This shows a method selection prompt and handles both flows.
|
||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
// Use the unified method selection flow (Builder ID or IDC)
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
return a.createAuthRecord(tokenData, "aws")
|
||||
}
|
||||
|
||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use AWS Builder ID authorization code flow
|
||||
tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-aws",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "aws-builder-id-authcode",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||
// NOTE: Google login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
// Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
return nil, fmt.Errorf("Google login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with Google\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||
// NOTE: GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
// Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
return nil, fmt.Errorf("GitHub login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with GitHub\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
|
||||
}
|
||||
|
||||
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
|
||||
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract email from JWT if not already set (for imported tokens)
|
||||
if tokenData.Email == "" {
|
||||
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||
if provider == "" {
|
||||
provider = "imported" // Fallback for legacy tokens without provider
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: fmt.Sprintf("kiro-%s", provider),
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
"region": tokenData.Region,
|
||||
"start_url": tokenData.StartURL,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
"region": tokenData.Region,
|
||||
},
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
|
||||
} else {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
|
||||
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil, fmt.Errorf("invalid auth record")
|
||||
}
|
||||
|
||||
refreshToken, ok := auth.Metadata["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token not found")
|
||||
}
|
||||
|
||||
clientID, _ := auth.Metadata["client_id"].(string)
|
||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
startURL, _ := auth.Metadata["start_url"].(string)
|
||||
region, _ := auth.Metadata["region"].(string)
|
||||
|
||||
var tokenData *kiroauth.KiroTokenData
|
||||
var err error
|
||||
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
|
||||
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||
switch {
|
||||
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||
default:
|
||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Clone auth to avoid mutating the input parameter
|
||||
updated := auth.Clone()
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
updated.Metadata["access_token"] = tokenData.AccessToken
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
|
||||
}
|
||||
return record, savedPath, nil
|
||||
}
|
||||
|
||||
// SaveAuth persists an auth record directly without going through the login flow.
|
||||
func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
|
||||
if m.store == nil {
|
||||
return "", fmt.Errorf("no store configured")
|
||||
}
|
||||
if cfg != nil {
|
||||
if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
}
|
||||
return m.store.Save(context.Background(), record)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ func init() {
|
||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -47,9 +47,9 @@ type RefreshEvaluator interface {
|
||||
}
|
||||
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshCheckInterval = 30 * time.Second
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
refreshFailureBackoff = 1 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
quotaBackoffMax = 30 * time.Minute
|
||||
)
|
||||
@@ -1987,7 +1987,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
updated.Runtime = auth.Runtime
|
||||
}
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
// Preserve NextRefreshAfter set by the Authenticator
|
||||
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
|
||||
@@ -221,7 +221,7 @@ func modelAliasChannel(auth *Auth) string {
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model alias (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
@@ -245,7 +245,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
|
||||
@@ -43,6 +43,15 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
input: "gemini-2.5-pro",
|
||||
want: "gemini-2.5-pro-exp-03-25",
|
||||
},
|
||||
{
|
||||
name: "kiro alias resolves",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"kiro": {{Name: "kiro-claude-sonnet-4-5", Alias: "sonnet"}},
|
||||
},
|
||||
channel: "kiro",
|
||||
input: "sonnet",
|
||||
want: "kiro-claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "config suffix takes priority",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
@@ -152,6 +161,8 @@ func createAuthForChannel(channel string) *Auth {
|
||||
return &Auth{Provider: "qwen"}
|
||||
case "iflow":
|
||||
return &Auth{Provider: "iflow"}
|
||||
case "kiro":
|
||||
return &Auth{Provider: "kiro"}
|
||||
default:
|
||||
return &Auth{Provider: channel}
|
||||
}
|
||||
|
||||
@@ -227,6 +227,18 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
}
|
||||
}
|
||||
|
||||
// For GitHub provider, return username
|
||||
if strings.ToLower(a.Provider) == "github" {
|
||||
if a.Metadata != nil {
|
||||
if username, ok := a.Metadata["username"].(string); ok {
|
||||
username = strings.TrimSpace(username)
|
||||
if username != "" {
|
||||
return "oauth", username
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check metadata for email first (OAuth-style auth)
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
@@ -97,6 +98,16 @@ func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) {
|
||||
usage.RegisterPlugin(plugin)
|
||||
}
|
||||
|
||||
// GetWatcher returns the underlying WatcherWrapper instance.
|
||||
// This allows external components (e.g., RefreshManager) to interact with the watcher.
|
||||
// Returns nil if the service or watcher is not initialized.
|
||||
func (s *Service) GetWatcher() *WatcherWrapper {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.watcher
|
||||
}
|
||||
|
||||
// newDefaultAuthManager creates a default authentication manager with all supported providers.
|
||||
func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
return sdkAuth.NewManager(
|
||||
@@ -379,6 +390,10 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||
case "iflow":
|
||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||
case "kiro":
|
||||
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
@@ -570,6 +585,18 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
}
|
||||
watcherWrapper.SetConfig(s.cfg)
|
||||
|
||||
// 方案 A: 连接 Kiro 后台刷新器回调到 Watcher
|
||||
// 当后台刷新器成功刷新 token 后,立即通知 Watcher 更新内存中的 Auth 对象
|
||||
// 这解决了后台刷新与内存 Auth 对象之间的时间差问题
|
||||
kiroauth.GetRefreshManager().SetOnTokenRefreshed(func(tokenID string, tokenData *kiroauth.KiroTokenData) {
|
||||
if tokenData == nil || watcherWrapper == nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro refresh callback: notifying watcher for token %s", tokenID)
|
||||
watcherWrapper.NotifyTokenRefreshed(tokenID, tokenData.AccessToken, tokenData.RefreshToken, tokenData.ExpiresAt)
|
||||
})
|
||||
log.Debug("kiro: connected background refresh callback to watcher")
|
||||
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
s.watcherCancel = watcherCancel
|
||||
if err = watcherWrapper.Start(watcherCtx); err != nil {
|
||||
@@ -767,6 +794,12 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "iflow":
|
||||
models = registry.GetIFlowModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "github-copilot":
|
||||
models = registry.GetGitHubCopilotModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kiro":
|
||||
models = s.fetchKiroModels(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
@@ -1328,3 +1361,201 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// fetchKiroModels attempts to dynamically fetch Kiro models from the API.
|
||||
// If dynamic fetch fails, it falls back to static registry.GetKiroModels().
|
||||
func (s *Service) fetchKiroModels(a *coreauth.Auth) []*ModelInfo {
|
||||
if a == nil {
|
||||
log.Debug("kiro: auth is nil, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Extract token data from auth attributes
|
||||
tokenData := s.extractKiroTokenData(a)
|
||||
if tokenData == nil || tokenData.AccessToken == "" {
|
||||
log.Debug("kiro: no valid token data in auth, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Create KiroAuth instance
|
||||
kAuth := kiroauth.NewKiroAuth(s.cfg)
|
||||
if kAuth == nil {
|
||||
log.Warn("kiro: failed to create KiroAuth instance, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Use timeout context for API call
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Attempt to fetch dynamic models
|
||||
apiModels, err := kAuth.ListAvailableModels(ctx, tokenData)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to fetch dynamic models: %v, using static models", err)
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
if len(apiModels) == 0 {
|
||||
log.Debug("kiro: API returned no models, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Convert API models to ModelInfo
|
||||
models := convertKiroAPIModels(apiModels)
|
||||
|
||||
// Generate agentic variants
|
||||
models = generateKiroAgenticVariants(models)
|
||||
|
||||
log.Infof("kiro: successfully fetched %d models from API (including agentic variants)", len(models))
|
||||
return models
|
||||
}
|
||||
|
||||
// extractKiroTokenData extracts KiroTokenData from auth attributes and metadata.
|
||||
func (s *Service) extractKiroTokenData(a *coreauth.Auth) *kiroauth.KiroTokenData {
|
||||
if a == nil || a.Attributes == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(a.Attributes["access_token"])
|
||||
if accessToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenData := &kiroauth.KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: strings.TrimSpace(a.Attributes["profile_arn"]),
|
||||
}
|
||||
|
||||
// Also try to get refresh token from metadata
|
||||
if a.Metadata != nil {
|
||||
if rt, ok := a.Metadata["refresh_token"].(string); ok {
|
||||
tokenData.RefreshToken = rt
|
||||
}
|
||||
}
|
||||
|
||||
return tokenData
|
||||
}
|
||||
|
||||
// convertKiroAPIModels converts Kiro API models to ModelInfo slice.
|
||||
func convertKiroAPIModels(apiModels []*kiroauth.KiroModel) []*ModelInfo {
|
||||
if len(apiModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
models := make([]*ModelInfo, 0, len(apiModels))
|
||||
|
||||
for _, m := range apiModels {
|
||||
if m == nil || m.ModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create model ID with kiro- prefix
|
||||
modelID := "kiro-" + normalizeKiroModelID(m.ModelID)
|
||||
|
||||
info := &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: formatKiroDisplayName(m.ModelName, m.RateMultiplier),
|
||||
Description: m.Description,
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
|
||||
if m.MaxInputTokens > 0 {
|
||||
info.ContextLength = m.MaxInputTokens
|
||||
}
|
||||
|
||||
models = append(models, info)
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// normalizeKiroModelID normalizes a Kiro model ID by converting dots to dashes
|
||||
// and removing common prefixes.
|
||||
func normalizeKiroModelID(modelID string) string {
|
||||
// Remove common prefixes
|
||||
modelID = strings.TrimPrefix(modelID, "anthropic.")
|
||||
modelID = strings.TrimPrefix(modelID, "amazon.")
|
||||
|
||||
// Replace dots with dashes for consistency
|
||||
modelID = strings.ReplaceAll(modelID, ".", "-")
|
||||
|
||||
// Replace underscores with dashes
|
||||
modelID = strings.ReplaceAll(modelID, "_", "-")
|
||||
|
||||
return strings.ToLower(modelID)
|
||||
}
|
||||
|
||||
// formatKiroDisplayName formats the display name with rate multiplier info.
|
||||
func formatKiroDisplayName(modelName string, rateMultiplier float64) string {
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
displayName := "Kiro " + modelName
|
||||
if rateMultiplier > 0 && rateMultiplier != 1.0 {
|
||||
displayName += fmt.Sprintf(" (%.1fx credit)", rateMultiplier)
|
||||
}
|
||||
|
||||
return displayName
|
||||
}
|
||||
|
||||
// generateKiroAgenticVariants generates agentic variants for Kiro models.
|
||||
// Agentic variants have optimized system prompts for coding agents.
|
||||
func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
result := make([]*ModelInfo, 0, len(models)*2)
|
||||
result = append(result, models...)
|
||||
|
||||
for _, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already an agentic variant
|
||||
if strings.HasSuffix(m.ID, "-agentic") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip auto models from agentic variant generation
|
||||
if strings.Contains(m.ID, "-auto") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create agentic variant
|
||||
agentic := &ModelInfo{
|
||||
ID: m.ID + "-agentic",
|
||||
Object: m.Object,
|
||||
Created: m.Created,
|
||||
OwnedBy: m.OwnedBy,
|
||||
Type: m.Type,
|
||||
DisplayName: m.DisplayName + " (Agentic)",
|
||||
Description: m.Description + " - Optimized for coding agents (chunked writes)",
|
||||
ContextLength: m.ContextLength,
|
||||
MaxCompletionTokens: m.MaxCompletionTokens,
|
||||
}
|
||||
|
||||
// Copy thinking support if present
|
||||
if m.Thinking != nil {
|
||||
agentic.Thinking = ®istry.ThinkingSupport{
|
||||
Min: m.Thinking.Min,
|
||||
Max: m.Thinking.Max,
|
||||
ZeroAllowed: m.Thinking.ZeroAllowed,
|
||||
DynamicAllowed: m.Thinking.DynamicAllowed,
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, agentic)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -89,6 +89,7 @@ type WatcherWrapper struct {
|
||||
snapshotAuths func() []*coreauth.Auth
|
||||
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
||||
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
|
||||
notifyTokenRefreshed func(tokenID, accessToken, refreshToken, expiresAt string) // 方案 A: 后台刷新通知
|
||||
}
|
||||
|
||||
// Start proxies to the underlying watcher Start implementation.
|
||||
@@ -146,3 +147,16 @@ func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) {
|
||||
}
|
||||
w.setUpdateQueue(queue)
|
||||
}
|
||||
|
||||
// NotifyTokenRefreshed 通知 Watcher 后台刷新器已更新 token
|
||||
// 这是方案 A 的核心方法,用于解决后台刷新与内存 Auth 对象的时间差问题
|
||||
// tokenID: token 文件名(如 kiro-xxx.json)
|
||||
// accessToken: 新的 access token
|
||||
// refreshToken: 新的 refresh token
|
||||
// expiresAt: 新的过期时间(RFC3339 格式)
|
||||
func (w *WatcherWrapper) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
if w == nil || w.notifyTokenRefreshed == nil {
|
||||
return
|
||||
}
|
||||
w.notifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
|
||||
}
|
||||
|
||||
@@ -31,5 +31,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
|
||||
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
|
||||
return w.DispatchRuntimeAuthUpdate(update)
|
||||
},
|
||||
notifyTokenRefreshed: func(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
w.NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user