From 7cb76ae1a59c435d95035006b44a7b0ca9207470 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 5 Jul 2025 07:53:46 +0800 Subject: [PATCH] Enhance quota management and refactor configuration handling - Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively. - Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion. - Refactored `APIHandlers` to leverage new configuration structure. - Simplified server initialization and removed redundant `ServerConfig` structure. - Streamlined client initialization by unifying configuration handling throughout the project. - Improved error handling and response mechanisms in both streaming and non-streaming flows. --- config.yaml | 7 +- internal/api/handlers.go | 258 ++++++++++++++++++-------------- internal/api/server.go | 25 +--- internal/client/client.go | 303 +++++++++++++++++++++++++------------- internal/cmd/run.go | 12 +- internal/config/config.go | 13 +- 6 files changed, 374 insertions(+), 244 deletions(-) diff --git a/config.yaml b/config.yaml index 3f8b1885..d5e7e16a 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,10 @@ port: 8317 -auth_dir: "~/.cli-proxy-api" +auth-dir: "~/.cli-proxy-api" debug: true proxy-url: "" -api_keys: +quota-exceeded: + switch-project: true + switch-preview-model: true +api-keys: - "12345" - "23456" \ No newline at end of file diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 5ca23ad9..9c399d9b 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "net/http" @@ -23,15 +24,15 @@ var ( // It holds a pool of clients to interact with the backend service. type APIHandlers struct { cliClients []*client.Client - debug bool + cfg *config.Config } // NewAPIHandlers creates a new API handlers instance. // It takes a slice of clients and a debug flag as input. -func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers { +func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers { return &APIHandlers{ cliClients: cliClients, - debug: debug, + cfg: cfg, } } @@ -216,48 +217,70 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) } }() - // Lock the mutex to update the last used page index - mutex.Lock() - startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - mutex.Unlock() + for { + // Lock the mutex to update the last used client index + mutex.Lock() + startIndex := lastUsedClientIndex + currentIndex := (startIndex + 1) % len(h.cliClients) + lastUsedClientIndex = currentIndex + mutex.Unlock() - // Reorder the pages to start from the last used index - reorderedPages := make([]*client.Client, len(h.cliClients)) - for i := 0; i < len(h.cliClients); i++ { - reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - } + // Reorder the client to start from the last used index + reorderedClients := make([]*client.Client, 0) + for i := 0; i < len(h.cliClients); i++ { + cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] + if cliClient.IsModelQuotaExceeded(modelName) { + log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + cliClient = nil + continue + } + reorderedClients = append(reorderedClients, cliClient) + } - locked := false - for i := 0; i < len(reorderedPages); i++ { - cliClient = reorderedPages[i] - if cliClient.RequestMutex.TryLock() { - locked = true + if len(reorderedClients) == 0 { + c.Status(429) + _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)) + flusher.Flush() + cliCancel() + return + } + + locked := false + for i := 0; i < len(reorderedClients); i++ { + cliClient = reorderedClients[i] + if cliClient.RequestMutex.TryLock() { + locked = true + break + } + } + if !locked { + cliClient = h.cliClients[0] + cliClient.RequestMutex.Lock() + } + + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + + resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools) + if err != nil { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + break + } else { + openAIFormat := translator.ConvertCliToOpenAINonStream(resp) + if openAIFormat != "" { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) + flusher.Flush() + } + cliCancel() break } } - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - - resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools) - if err != nil { - c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() - cliCancel() - } else { - openAIFormat := translator.ConvertCliToOpenAINonStream(resp) - if openAIFormat != "" { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) - flusher.Flush() - } - cliCancel() - } } // handleStreamingResponse handles streaming responses @@ -290,79 +313,98 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { } }() - // Use a round-robin approach to select the next available client. - // This distributes the load among the available clients. - mutex.Lock() - startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - mutex.Unlock() - - // Reorder the clients to start from the next client in the rotation. - reorderedPages := make([]*client.Client, len(h.cliClients)) - for i := 0; i < len(h.cliClients); i++ { - reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - } - - // Attempt to lock a client for the request. - locked := false - for i := 0; i < len(reorderedPages); i++ { - cliClient = reorderedPages[i] - if cliClient.RequestMutex.TryLock() { - locked = true - break - } - } - // If no client is available, block and wait for the first client. - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) - hasFirstResponse := false +outLoop: for { - select { - // Handle client disconnection. - case <-c.Request.Context().Done(): - if c.Request.Context().Err().Error() == "context canceled" { - log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() // Cancel the backend request. - return + // Lock the mutex to update the last used client index + mutex.Lock() + startIndex := lastUsedClientIndex + currentIndex := (startIndex + 1) % len(h.cliClients) + lastUsedClientIndex = currentIndex + mutex.Unlock() + + // Reorder the client to start from the last used index + reorderedClients := make([]*client.Client, 0) + for i := 0; i < len(h.cliClients); i++ { + cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] + if cliClient.IsModelQuotaExceeded(modelName) { + log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + cliClient = nil + continue } - // Process incoming response chunks. - case chunk, okStream := <-respChan: - if !okStream { - // Stream is closed, send the final [DONE] message. - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cliCancel() - return - } else { - // Convert the chunk to OpenAI format and send it to the client. - hasFirstResponse = true - openAIFormat := translator.ConvertCliToOpenAI(chunk) - if openAIFormat != "" { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) + reorderedClients = append(reorderedClients, cliClient) + } + + if len(reorderedClients) == 0 { + c.Status(429) + _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)) + flusher.Flush() + cliCancel() + return + } + + locked := false + for i := 0; i < len(reorderedClients); i++ { + cliClient = reorderedClients[i] + if cliClient.RequestMutex.TryLock() { + locked = true + break + } + } + if !locked { + cliClient = h.cliClients[0] + cliClient.RequestMutex.Lock() + } + + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) + hasFirstResponse := false + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("Client disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + // Stream is closed, send the final [DONE] message. + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel() + return + } else { + // Convert the chunk to OpenAI format and send it to the client. + hasFirstResponse = true + openAIFormat := translator.ConvertCliToOpenAI(chunk) + if openAIFormat != "" { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) + flusher.Flush() + } + } + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + if hasFirstResponse { + _, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n")) flusher.Flush() } } - // Handle errors from the backend. - case err, okError := <-errChan: - if okError { - c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() - cliCancel() - return - } - // Send a keep-alive signal to the client. - case <-time.After(500 * time.Millisecond): - if hasFirstResponse { - _, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n")) - flusher.Flush() - } } } } diff --git a/internal/api/server.go b/internal/api/server.go index 8bfea165..d507eabc 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" "net/http" "strings" @@ -17,29 +18,19 @@ type Server struct { engine *gin.Engine server *http.Server handlers *APIHandlers - cfg *ServerConfig -} - -// ServerConfig contains the configuration for the API server. -type ServerConfig struct { - // Port is the port number the server will listen on. - Port string - // Debug enables or disables debug mode for the server and Gin. - Debug bool - // ApiKeys is a list of valid API keys for authentication. - ApiKeys []string + cfg *config.Config } // NewServer creates and initializes a new API server instance. // It sets up the Gin engine, middleware, routes, and handlers. -func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { +func NewServer(cfg *config.Config, cliClients []*client.Client) *Server { // Set gin mode - if !config.Debug { + if !cfg.Debug { gin.SetMode(gin.ReleaseMode) } // Create handlers - handlers := NewAPIHandlers(cliClients, config.Debug) + handlers := NewAPIHandlers(cliClients, cfg) // Create gin engine engine := gin.New() @@ -53,7 +44,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { s := &Server{ engine: engine, handlers: handlers, - cfg: config, + cfg: cfg, } // Setup routes @@ -61,7 +52,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { // Create HTTP server s.server = &http.Server{ - Addr: ":" + config.Port, + Addr: fmt.Sprintf(":%d", cfg.Port), Handler: engine, } @@ -138,7 +129,7 @@ func corsMiddleware() gin.HandlerFunc { // AuthMiddleware returns a Gin middleware handler that authenticates requests // using API keys. If no API keys are configured, it allows all requests. -func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc { +func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { if len(cfg.ApiKeys) == 0 { c.Next() diff --git a/internal/client/client.go b/internal/client/client.go index 0f189060..c99dfe55 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -29,20 +29,29 @@ const ( pluginVersion = "0.1.9" ) +var ( + previewModels = map[string][]string{ + "gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"}, + "gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"}, + } +) + // Client is the main client for interacting with the CLI API. type Client struct { - httpClient *http.Client - RequestMutex sync.Mutex - tokenStorage *auth.TokenStorage - cfg *config.Config + httpClient *http.Client + RequestMutex sync.Mutex + tokenStorage *auth.TokenStorage + cfg *config.Config + modelQuotaExceeded map[string]*time.Time } // NewClient creates a new CLI API client. func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client { return &Client{ - httpClient: httpClient, - tokenStorage: ts, - cfg: cfg, + httpClient: httpClient, + tokenStorage: ts, + cfg: cfg, + modelQuotaExceeded: make(map[string]*time.Time), } } @@ -214,97 +223,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo return nil } -// SendMessageStream handles a single conversational turn, including tool calls. -func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { - dataTag := []byte("data: ") - errChan := make(chan *ErrorMessage) - dataChan := make(chan []byte) - go func() { - defer close(errChan) - defer close(dataChan) - - request := GenerateContentRequest{ - Contents: contents, - GenerationConfig: GenerationConfig{ - ThinkingConfig: GenerationConfigThinkingConfig{ - IncludeThoughts: true, - }, - }, - } - request.Tools = tools - - requestBody := map[string]interface{}{ - "project": c.tokenStorage.ProjectID, // Assuming ProjectID is available - "request": request, - "model": model, - } - - byteRequestBody, _ := json.Marshal(requestBody) - - // log.Debug(string(byteRequestBody)) - - reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") - if reasoningEffortResult.String() == "none" { - byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) - } else if reasoningEffortResult.String() == "auto" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - } else if reasoningEffortResult.String() == "low" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) - } else if reasoningEffortResult.String() == "medium" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) - } else if reasoningEffortResult.String() == "high" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) - } else { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - } - - temperatureResult := gjson.GetBytes(rawJson, "temperature") - if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) - } - - topPResult := gjson.GetBytes(rawJson, "top_p") - if topPResult.Exists() && topPResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) - } - - topKResult := gjson.GetBytes(rawJson, "top_k") - if topKResult.Exists() && topKResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) - } - - // log.Debug(string(byteRequestBody)) - - stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true) - if err != nil { - // log.Println(err) - errChan <- err - return - } - - scanner := bufio.NewScanner(stream) - for scanner.Scan() { - line := scanner.Bytes() - // log.Printf("Received stream chunk: %s", line) - if bytes.HasPrefix(line, dataTag) { - dataChan <- line[6:] - } - } - - if errScanner := scanner.Err(); errScanner != nil { - // log.Println(err) - errChan <- &ErrorMessage{500, errScanner} - _ = stream.Close() - return - } - - _ = stream.Close() - }() - - return dataChan, errChan -} - // APIRequest handles making requests to the CLI API endpoints. func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) { var jsonBody []byte @@ -415,17 +333,192 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) } + modelName := model // log.Debug(string(byteRequestBody)) + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName) + continue + } + } + return nil, &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + } - respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false) - if err != nil { - return nil, err + respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel { + continue + } + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil } - bodyBytes, errReadAll := io.ReadAll(respBody) - if errReadAll != nil { - return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} +} + +// SendMessageStream handles a single conversational turn, including tool calls. +func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { + dataTag := []byte("data: ") + errChan := make(chan *ErrorMessage) + dataChan := make(chan []byte) + go func() { + defer close(errChan) + defer close(dataChan) + + request := GenerateContentRequest{ + Contents: contents, + GenerationConfig: GenerationConfig{ + ThinkingConfig: GenerationConfigThinkingConfig{ + IncludeThoughts: true, + }, + }, + } + request.Tools = tools + + requestBody := map[string]interface{}{ + "project": c.tokenStorage.ProjectID, // Assuming ProjectID is available + "request": request, + "model": model, + } + + byteRequestBody, _ := json.Marshal(requestBody) + + // log.Debug(string(byteRequestBody)) + + reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") + if reasoningEffortResult.String() == "none" { + byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) + } else if reasoningEffortResult.String() == "auto" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + } else if reasoningEffortResult.String() == "low" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) + } else if reasoningEffortResult.String() == "medium" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) + } else if reasoningEffortResult.String() == "high" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) + } else { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + } + + temperatureResult := gjson.GetBytes(rawJson, "temperature") + if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) + } + + topPResult := gjson.GetBytes(rawJson, "top_p") + if topPResult.Exists() && topPResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) + } + + topKResult := gjson.GetBytes(rawJson, "top_k") + if topKResult.Exists() && topKResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) + } + + // log.Debug(string(byteRequestBody)) + modelName := model + var stream io.ReadCloser + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName) + continue + } + } + errChan <- &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + return + } + var err *ErrorMessage + stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel { + continue + } + } + errChan <- err + return + } + delete(c.modelQuotaExceeded, modelName) + break + } + + scanner := bufio.NewScanner(stream) + for scanner.Scan() { + line := scanner.Bytes() + // log.Printf("Received stream chunk: %s", line) + if bytes.HasPrefix(line, dataTag) { + dataChan <- line[6:] + } + } + + if errScanner := scanner.Err(); errScanner != nil { + // log.Println(err) + errChan <- &ErrorMessage{500, errScanner} + _ = stream.Close() + return + } + + _ = stream.Close() + }() + + return dataChan, errChan +} + +func (c *Client) isModelQuotaExceeded(model string) bool { + if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { + duration := time.Now().Sub(*lastExceededTime) + if duration > 30*time.Minute { + return false + } + return true } - return bodyBytes, nil + return false +} + +func (c *Client) getPreviewModel(model string) string { + if models, hasKey := previewModels[model]; hasKey { + for i := 0; i < len(models); i++ { + if !c.isModelQuotaExceeded(models[i]) { + return models[i] + } + } + } + return "" +} + +func (c *Client) IsModelQuotaExceeded(model string) bool { + if c.isModelQuotaExceeded(model) { + if c.cfg.QuotaExceeded.SwitchPreviewModel { + return c.getPreviewModel(model) == "" + } + return true + } + return false } // CheckCloudAPIIsEnabled sends a simple test request to the API to verify diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 5fa6ebb1..12bfc032 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -3,7 +3,6 @@ package cmd import ( "context" "encoding/json" - "fmt" "github.com/luispater/CLIProxyAPI/internal/api" "github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/client" @@ -22,13 +21,6 @@ import ( // It loads all available authentication tokens, creates a pool of clients, // starts the API server, and handles graceful shutdown signals. func StartService(cfg *config.Config) { - // Configure the API server based on the main application config. - apiConfig := &api.ServerConfig{ - Port: fmt.Sprintf("%d", cfg.Port), - Debug: cfg.Debug, - ApiKeys: cfg.ApiKeys, - } - // Create a pool of API clients, one for each token file found. cliClients := make([]*client.Client, 0) err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error { @@ -73,8 +65,8 @@ func StartService(cfg *config.Config) { } // Create and start the API server with the pool of clients. - apiServer := api.NewServer(apiConfig, cliClients) - log.Infof("Starting API server on port %s", apiConfig.Port) + apiServer := api.NewServer(cfg, cliClients) + log.Infof("Starting API server on port %d", cfg.Port) if err = apiServer.Start(); err != nil { log.Fatalf("API server failed to start: %v", err) } diff --git a/internal/config/config.go b/internal/config/config.go index 38fb864d..4af1d9a8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,13 +11,22 @@ type Config struct { // Port is the network port on which the API server will listen. Port int `yaml:"port"` // AuthDir is the directory where authentication token files are stored. - AuthDir string `yaml:"auth_dir"` + AuthDir string `yaml:"auth-dir"` // Debug enables or disables debug-level logging and other debug features. Debug bool `yaml:"debug"` // ProxyUrl is the URL of an optional proxy server to use for outbound requests. ProxyUrl string `yaml:"proxy-url"` // ApiKeys is a list of keys for authenticating clients to this proxy server. - ApiKeys []string `yaml:"api_keys"` + ApiKeys []string `yaml:"api-keys"` + // QuotaExceeded defines the behavior when a quota is exceeded. + QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"` +} + +type ConfigQuotaExceeded struct { + // SwitchProject indicates whether to automatically switch to another project when a quota is exceeded. + SwitchProject bool `yaml:"switch-project"` + // SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded. + SwitchPreviewModel bool `yaml:"switch-preview-model"` } // LoadConfig reads a YAML configuration file from the given path,