mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-17 20:03:42 +00:00
Add GitLab Duo OAuth and PAT support
This commit is contained in:
@@ -79,6 +79,8 @@ func main() {
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var gitlabLogin bool
|
||||
var gitlabTokenLogin bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
@@ -111,6 +113,8 @@ func main() {
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
|
||||
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
@@ -527,6 +531,10 @@ func main() {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if gitlabLogin {
|
||||
cmd.DoGitLabLogin(cfg, options)
|
||||
} else if gitlabTokenLogin {
|
||||
cmd.DoGitLabTokenLogin(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
|
||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBaseURL = "https://gitlab.com"
|
||||
DefaultCallbackPort = 17171
|
||||
defaultOAuthScope = "api read_user"
|
||||
)
|
||||
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type OAuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
type OAuthServer struct {
|
||||
server *http.Server
|
||||
port int
|
||||
resultChan chan *OAuthResult
|
||||
errorChan chan error
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
PublicEmail string `json:"public_email"`
|
||||
}
|
||||
|
||||
type PersonalAccessTokenSelf struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
type ModelDetails struct {
|
||||
ModelProvider string `json:"model_provider"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
type DirectAccessResponse struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
ModelDetails *ModelDetails `json:"model_details,omitempty"`
|
||||
}
|
||||
|
||||
type DiscoveredModel struct {
|
||||
ModelProvider string
|
||||
ModelName string
|
||||
}
|
||||
|
||||
type AuthClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewAuthClient(cfg *config.Config) *AuthClient {
|
||||
client := &http.Client{}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &AuthClient{httpClient: client}
|
||||
}
|
||||
|
||||
func NormalizeBaseURL(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return DefaultBaseURL
|
||||
}
|
||||
if !strings.Contains(value, "://") {
|
||||
value = "https://" + value
|
||||
}
|
||||
value = strings.TrimRight(value, "/")
|
||||
return value
|
||||
}
|
||||
|
||||
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
||||
if token == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
||||
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
||||
}
|
||||
if token.ExpiresIn > 0 {
|
||||
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
verifierBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
||||
}
|
||||
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return &PKCECodes{
|
||||
CodeVerifier: verifier,
|
||||
CodeChallenge: challenge,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
resultChan: make(chan *OAuthResult, 1),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("gitlab oauth server already running")
|
||||
}
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
s.running = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.errorChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
s.running = false
|
||||
s.server = nil
|
||||
}()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case result := <-s.resultChan:
|
||||
return result, nil
|
||||
case err := <-s.errorChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
query := r.URL.Query()
|
||||
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||
s.sendResult(&OAuthResult{Error: errParam})
|
||||
http.Error(w, errParam, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
if code == "" || state == "" {
|
||||
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
||||
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
||||
}
|
||||
|
||||
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||
select {
|
||||
case s.resultChan <- result:
|
||||
default:
|
||||
log.Debug("gitlab oauth result channel full, dropping callback result")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func RedirectURL(port int) string {
|
||||
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
|
||||
}
|
||||
|
||||
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
||||
if pkce == nil {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
||||
}
|
||||
if strings.TrimSpace(clientID) == "" {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
||||
}
|
||||
baseURL = NormalizeBaseURL(baseURL)
|
||||
params := url.Values{
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"scope": {defaultOAuthScope},
|
||||
"state": {strings.TrimSpace(state)},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
}
|
||||
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"code": {strings.TrimSpace(code)},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||
}
|
||||
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
||||
form.Set("client_id", clientID)
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var token TokenResponse
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var pat PersonalAccessTokenSelf
|
||||
if err := json.Unmarshal(body, &pat); err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
||||
}
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var direct DirectAccessResponse
|
||||
if err := json.Unmarshal(body, &direct); err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
||||
}
|
||||
if direct.Headers == nil {
|
||||
direct.Headers = make(map[string]string)
|
||||
}
|
||||
return &direct, nil
|
||||
}
|
||||
|
||||
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
models := make([]DiscoveredModel, 0, 4)
|
||||
seen := make(map[string]struct{})
|
||||
appendModel := func(provider, name string) {
|
||||
provider = strings.TrimSpace(provider)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(provider + "\x00" + name)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
models = append(models, DiscoveredModel{
|
||||
ModelProvider: provider,
|
||||
ModelName: name,
|
||||
})
|
||||
}
|
||||
|
||||
if raw, ok := metadata["model_details"]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||
|
||||
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||
if raw, ok := metadata[key]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||
if nested, ok := typed["models"]; ok {
|
||||
appendDiscoveredModels(nested, appendModel)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range typed {
|
||||
appendDiscoveredModels(item, appendModel)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range typed {
|
||||
appendModel("", item)
|
||||
}
|
||||
case string:
|
||||
appendModel("", typed)
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(raw any) string {
|
||||
switch typed := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(typed.String())
|
||||
case json.Number:
|
||||
return typed.String()
|
||||
case int:
|
||||
return strconv.Itoa(typed)
|
||||
case int64:
|
||||
return strconv.FormatInt(typed, 10)
|
||||
case float64:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
72
internal/auth/gitlab/gitlab_test.go
Normal file
72
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "default", in: "", want: DefaultBaseURL},
|
||||
{name: "plain host", in: "gitlab.example.com", want: "https://gitlab.example.com"},
|
||||
{name: "trim trailing slash", in: "https://gitlab.example.com/", want: "https://gitlab.example.com"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := NormalizeBaseURL(tc.in); got != tc.want {
|
||||
t.Fatalf("NormalizeBaseURL(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchDirectAccess_ParsesModelDetails(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer pat-123" {
|
||||
t.Fatalf("expected Authorization header, got %q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"base_url":"https://gateway.gitlab.example.com/v1",
|
||||
"token":"duo-gateway-token",
|
||||
"expires_at":2000000000,
|
||||
"headers":{
|
||||
"X-Gitlab-Realm":"saas",
|
||||
"X-Gitlab-Host-Name":"gitlab.example.com"
|
||||
},
|
||||
"model_details":{
|
||||
"model_provider":"anthropic",
|
||||
"model_name":"claude-sonnet-4-5"
|
||||
}
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &AuthClient{httpClient: server.Client()}
|
||||
direct, err := client.FetchDirectAccess(context.Background(), server.URL, "pat-123")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchDirectAccess returned error: %v", err)
|
||||
}
|
||||
if direct.BaseURL != "https://gateway.gitlab.example.com/v1" {
|
||||
t.Fatalf("unexpected base_url %q", direct.BaseURL)
|
||||
}
|
||||
if direct.Token != "duo-gateway-token" {
|
||||
t.Fatalf("unexpected token %q", direct.Token)
|
||||
}
|
||||
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("unexpected model details: %+v", direct.ModelDetails)
|
||||
}
|
||||
if direct.Headers["X-Gitlab-Realm"] != "saas" {
|
||||
t.Fatalf("expected X-Gitlab-Realm header, got %+v", direct.Headers)
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
69
internal/cmd/gitlab_login.go
Normal file
69
internal/cmd/gitlab_login.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "oauth",
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("GitLab Duo authentication successful!")
|
||||
}
|
||||
|
||||
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "pat",
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("GitLab Duo PAT authentication successful!")
|
||||
}
|
||||
746
internal/runtime/executor/gitlab_executor.go
Normal file
746
internal/runtime/executor/gitlab_executor.go
Normal file
@@ -0,0 +1,746 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
gitLabProviderKey = "gitlab"
|
||||
gitLabAuthMethodOAuth = "oauth"
|
||||
gitLabAuthMethodPAT = "pat"
|
||||
gitLabChatEndpoint = "/api/v4/chat/completions"
|
||||
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
|
||||
)
|
||||
|
||||
type GitLabExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type gitLabPrompt struct {
|
||||
Instruction string
|
||||
FileName string
|
||||
ContentAboveCursor string
|
||||
ChatContext []map[string]any
|
||||
CodeSuggestionContext []map[string]any
|
||||
}
|
||||
|
||||
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
||||
return &GitLabExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) Identifier() string { return gitLabProviderKey }
|
||||
|
||||
func (e *GitLabExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
translated, err := e.translateToOpenAI(req, opts)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
prompt := buildGitLabPrompt(translated)
|
||||
if strings.TrimSpace(prompt.Instruction) == "" && strings.TrimSpace(prompt.ContentAboveCursor) == "" {
|
||||
err = statusErr{code: http.StatusBadRequest, msg: "gitlab duo executor: request has no usable text content"}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
text, err := e.invoke(ctx, auth, prompt)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
responseModel := gitLabResolvedModel(auth, req.Model)
|
||||
openAIResponse := buildGitLabOpenAIResponse(responseModel, text, translated)
|
||||
reporter.publish(ctx, parseOpenAIUsage(openAIResponse))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(
|
||||
ctx,
|
||||
sdktranslator.FromString("openai"),
|
||||
opts.SourceFormat,
|
||||
req.Model,
|
||||
opts.OriginalRequest,
|
||||
translated,
|
||||
openAIResponse,
|
||||
¶m,
|
||||
)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: make(http.Header)}, nil
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
translated, err := e.translateToOpenAI(req, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompt := buildGitLabPrompt(translated)
|
||||
if strings.TrimSpace(prompt.Instruction) == "" && strings.TrimSpace(prompt.ContentAboveCursor) == "" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "gitlab duo executor: request has no usable text content"}
|
||||
}
|
||||
|
||||
text, err := e.invoke(ctx, auth, prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responseModel := gitLabResolvedModel(auth, req.Model)
|
||||
openAIResponse := buildGitLabOpenAIResponse(responseModel, text, translated)
|
||||
reporter.publish(ctx, parseOpenAIUsage(openAIResponse))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk, 8)
|
||||
go func() {
|
||||
defer close(out)
|
||||
var param any
|
||||
lines := buildGitLabOpenAIStream(responseModel, text)
|
||||
for _, line := range lines {
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
ctx,
|
||||
sdktranslator.FromString("openai"),
|
||||
opts.SourceFormat,
|
||||
req.Model,
|
||||
opts.OriginalRequest,
|
||||
translated,
|
||||
[]byte(line),
|
||||
¶m,
|
||||
)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: make(http.Header), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("gitlab duo executor: auth is nil")
|
||||
}
|
||||
baseURL := gitLabBaseURL(auth)
|
||||
token := gitLabPrimaryToken(auth)
|
||||
if baseURL == "" || token == "" {
|
||||
return nil, fmt.Errorf("gitlab duo executor: missing base URL or token")
|
||||
}
|
||||
|
||||
client := gitlab.NewAuthClient(e.cfg)
|
||||
method := strings.ToLower(strings.TrimSpace(gitLabMetadataString(auth.Metadata, "auth_method", "auth_kind")))
|
||||
if method == "" {
|
||||
method = gitLabAuthMethodOAuth
|
||||
}
|
||||
|
||||
if method == gitLabAuthMethodOAuth {
|
||||
if refreshed, refreshErr := e.refreshOAuthToken(ctx, client, auth, baseURL); refreshErr == nil && refreshed != nil {
|
||||
token = refreshed.AccessToken
|
||||
applyGitLabTokenMetadata(auth.Metadata, refreshed)
|
||||
}
|
||||
}
|
||||
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
|
||||
if err != nil && method == gitLabAuthMethodOAuth {
|
||||
if refreshed, refreshErr := e.refreshOAuthToken(ctx, client, auth, baseURL); refreshErr == nil && refreshed != nil {
|
||||
token = refreshed.AccessToken
|
||||
applyGitLabTokenMetadata(auth.Metadata, refreshed)
|
||||
direct, err = client.FetchDirectAccess(ctx, baseURL, token)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["type"] = gitLabProviderKey
|
||||
auth.Metadata["auth_method"] = method
|
||||
auth.Metadata["auth_kind"] = gitLabAuthKind(method)
|
||||
auth.Metadata["base_url"] = gitlab.NormalizeBaseURL(baseURL)
|
||||
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
|
||||
mergeGitLabDirectAccessMetadata(auth.Metadata, direct)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
translated := sdktranslator.TranslateRequest(opts.SourceFormat, sdktranslator.FromString("openai"), baseModel, req.Payload, false)
|
||||
enc, err := tokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("gitlab duo executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
count, err := countOpenAIChatTokens(enc, translated)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
return cliproxyexecutor.Response{Payload: buildOpenAIUsageJSON(count), Headers: make(http.Header)}, nil
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("gitlab duo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if token := gitLabPrimaryToken(auth); token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) translateToOpenAI(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
return sdktranslator.TranslateRequest(opts.SourceFormat, sdktranslator.FromString("openai"), baseModel, req.Payload, opts.Stream), nil
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) invoke(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) {
|
||||
if text, err := e.requestChat(ctx, auth, prompt); err == nil {
|
||||
return text, nil
|
||||
} else if !shouldFallbackToCodeSuggestions(err) {
|
||||
return "", err
|
||||
}
|
||||
return e.requestCodeSuggestions(ctx, auth, prompt)
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) requestChat(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) {
|
||||
body := map[string]any{
|
||||
"content": prompt.Instruction,
|
||||
"with_clean_history": true,
|
||||
}
|
||||
if len(prompt.ChatContext) > 0 {
|
||||
body["additional_context"] = prompt.ChatContext
|
||||
}
|
||||
return e.doJSONTextRequest(ctx, auth, gitLabChatEndpoint, body)
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) requestCodeSuggestions(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) {
|
||||
contentAbove := strings.TrimSpace(prompt.ContentAboveCursor)
|
||||
if contentAbove == "" {
|
||||
contentAbove = prompt.Instruction
|
||||
}
|
||||
body := map[string]any{
|
||||
"current_file": map[string]any{
|
||||
"file_name": prompt.FileName,
|
||||
"content_above_cursor": contentAbove,
|
||||
"content_below_cursor": "",
|
||||
},
|
||||
"intent": "generation",
|
||||
"generation_type": "small_file",
|
||||
"user_instruction": prompt.Instruction,
|
||||
"stream": false,
|
||||
}
|
||||
if len(prompt.CodeSuggestionContext) > 0 {
|
||||
body["context"] = prompt.CodeSuggestionContext
|
||||
}
|
||||
return e.doJSONTextRequest(ctx, auth, gitLabCodeSuggestionsEndpoint, body)
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) doJSONTextRequest(ctx context.Context, auth *cliproxyauth.Auth, endpoint string, payload map[string]any) (string, error) {
|
||||
token := gitLabPrimaryToken(auth)
|
||||
baseURL := gitLabBaseURL(auth)
|
||||
if token == "" || baseURL == "" {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: "gitlab duo executor: missing credentials"}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gitlab duo executor: marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(baseURL, "/") + endpoint
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "CLIProxyAPI/GitLab-Duo")
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: req.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return "", err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", statusErr{code: resp.StatusCode, msg: strings.TrimSpace(string(respBody))}
|
||||
}
|
||||
|
||||
text, err := parseGitLabTextResponse(endpoint, respBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(text), nil
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) refreshOAuthToken(ctx context.Context, client *gitlab.AuthClient, auth *cliproxyauth.Auth, baseURL string) (*gitlab.TokenResponse, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("gitlab duo executor: auth is nil")
|
||||
}
|
||||
refreshToken := gitLabMetadataString(auth.Metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("gitlab duo executor: refresh token missing")
|
||||
}
|
||||
if !gitLabOAuthTokenNeedsRefresh(auth.Metadata) && gitLabPrimaryToken(auth) != "" {
|
||||
return nil, nil
|
||||
}
|
||||
return client.RefreshTokens(
|
||||
ctx,
|
||||
baseURL,
|
||||
gitLabMetadataString(auth.Metadata, "oauth_client_id"),
|
||||
gitLabMetadataString(auth.Metadata, "oauth_client_secret"),
|
||||
refreshToken,
|
||||
)
|
||||
}
|
||||
|
||||
func buildGitLabPrompt(payload []byte) gitLabPrompt {
|
||||
root := gjson.ParseBytes(payload)
|
||||
prompt := gitLabPrompt{
|
||||
FileName: "prompt.txt",
|
||||
}
|
||||
|
||||
msgs := root.Get("messages")
|
||||
if msgs.Exists() && msgs.IsArray() {
|
||||
systemIndex := 0
|
||||
contextIndex := 0
|
||||
transcript := make([]string, 0, len(msgs.Array()))
|
||||
var lastUser string
|
||||
msgs.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := strings.TrimSpace(msg.Get("role").String())
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
content := openAIContentText(msg.Get("content"))
|
||||
if content == "" {
|
||||
return true
|
||||
}
|
||||
switch role {
|
||||
case "system":
|
||||
systemIndex++
|
||||
prompt.ChatContext = append(prompt.ChatContext, map[string]any{
|
||||
"category": "snippet",
|
||||
"id": fmt.Sprintf("system-%d", systemIndex),
|
||||
"content": content,
|
||||
})
|
||||
case "user":
|
||||
lastUser = content
|
||||
contextIndex++
|
||||
prompt.CodeSuggestionContext = append(prompt.CodeSuggestionContext, map[string]any{
|
||||
"type": "snippet",
|
||||
"name": fmt.Sprintf("user-%d", contextIndex),
|
||||
"content": content,
|
||||
})
|
||||
transcript = append(transcript, "User:\n"+content)
|
||||
default:
|
||||
contextIndex++
|
||||
prompt.ChatContext = append(prompt.ChatContext, map[string]any{
|
||||
"category": "snippet",
|
||||
"id": fmt.Sprintf("%s-%d", role, contextIndex),
|
||||
"content": content,
|
||||
})
|
||||
prompt.CodeSuggestionContext = append(prompt.CodeSuggestionContext, map[string]any{
|
||||
"type": "snippet",
|
||||
"name": fmt.Sprintf("%s-%d", role, contextIndex),
|
||||
"content": content,
|
||||
})
|
||||
transcript = append(transcript, strings.Title(role)+":\n"+content)
|
||||
}
|
||||
return true
|
||||
})
|
||||
prompt.Instruction = strings.TrimSpace(lastUser)
|
||||
prompt.ContentAboveCursor = truncateGitLabPrompt(strings.Join(transcript, "\n\n"), 12000)
|
||||
}
|
||||
|
||||
if prompt.Instruction == "" {
|
||||
for _, key := range []string{"prompt", "input", "instructions"} {
|
||||
if value := strings.TrimSpace(root.Get(key).String()); value != "" {
|
||||
prompt.Instruction = value
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if prompt.ContentAboveCursor == "" {
|
||||
prompt.ContentAboveCursor = prompt.Instruction
|
||||
}
|
||||
prompt.Instruction = truncateGitLabPrompt(prompt.Instruction, 4000)
|
||||
prompt.ContentAboveCursor = truncateGitLabPrompt(prompt.ContentAboveCursor, 12000)
|
||||
return prompt
|
||||
}
|
||||
|
||||
func openAIContentText(content gjson.Result) string {
|
||||
segments := make([]string, 0, 8)
|
||||
collectOpenAIContent(content, &segments)
|
||||
return strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
}
|
||||
|
||||
func truncateGitLabPrompt(value string, limit int) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if limit <= 0 || len(value) <= limit {
|
||||
return value
|
||||
}
|
||||
return strings.TrimSpace(value[:limit])
|
||||
}
|
||||
|
||||
func parseGitLabTextResponse(endpoint string, body []byte) (string, error) {
|
||||
if endpoint == gitLabChatEndpoint {
|
||||
var text string
|
||||
if err := json.Unmarshal(body, &text); err == nil {
|
||||
return text, nil
|
||||
}
|
||||
if value := strings.TrimSpace(gjson.GetBytes(body, "response").String()); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
if value := strings.TrimSpace(gjson.GetBytes(body, "choices.0.text").String()); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if value := strings.TrimSpace(gjson.GetBytes(body, "response").String()); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
var plain string
|
||||
if err := json.Unmarshal(body, &plain); err == nil && strings.TrimSpace(plain) != "" {
|
||||
return plain, nil
|
||||
}
|
||||
return "", fmt.Errorf("gitlab duo executor: upstream returned no text payload")
|
||||
}
|
||||
|
||||
func shouldFallbackToCodeSuggestions(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status, ok := err.(interface{ StatusCode() int })
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch status.StatusCode() {
|
||||
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func buildGitLabOpenAIResponse(model, text string, translatedReq []byte) []byte {
|
||||
promptTokens, completionTokens := gitLabUsage(model, translatedReq, text)
|
||||
payload := map[string]any{
|
||||
"id": fmt.Sprintf("gitlab-%d", time.Now().UnixNano()),
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": completionTokens,
|
||||
"total_tokens": promptTokens + completionTokens,
|
||||
},
|
||||
}
|
||||
raw, _ := json.Marshal(payload)
|
||||
return raw
|
||||
}
|
||||
|
||||
func buildGitLabOpenAIStream(model, text string) []string {
|
||||
now := time.Now().Unix()
|
||||
id := fmt.Sprintf("gitlab-%d", time.Now().UnixNano())
|
||||
chunks := []map[string]any{
|
||||
{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"role": "assistant"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": text},
|
||||
}},
|
||||
},
|
||||
{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
},
|
||||
}
|
||||
lines := make([]string, 0, len(chunks)+1)
|
||||
for _, chunk := range chunks {
|
||||
raw, _ := json.Marshal(chunk)
|
||||
lines = append(lines, "data: "+string(raw))
|
||||
}
|
||||
lines = append(lines, "data: [DONE]")
|
||||
return lines
|
||||
}
|
||||
|
||||
func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64) {
|
||||
enc, err := tokenizerForModel(model)
|
||||
if err != nil {
|
||||
return 0, 0
|
||||
}
|
||||
promptTokens, err := countOpenAIChatTokens(enc, translatedReq)
|
||||
if err != nil {
|
||||
promptTokens = 0
|
||||
}
|
||||
completionCount, err := enc.Count(strings.TrimSpace(text))
|
||||
if err != nil {
|
||||
return promptTokens, 0
|
||||
}
|
||||
return promptTokens, int64(completionCount)
|
||||
}
|
||||
|
||||
func gitLabPrimaryToken(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
if token := gitLabMetadataString(auth.Metadata, "access_token"); token != "" {
|
||||
return token
|
||||
}
|
||||
return gitLabMetadataString(auth.Metadata, "personal_access_token")
|
||||
}
|
||||
|
||||
func gitLabBaseURL(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
return gitlab.NormalizeBaseURL(gitLabMetadataString(auth.Metadata, "base_url"))
|
||||
}
|
||||
|
||||
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
|
||||
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
|
||||
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
|
||||
return requested
|
||||
}
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
for _, model := range gitlab.ExtractDiscoveredModels(auth.Metadata) {
|
||||
if name := strings.TrimSpace(model.ModelName); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
if requested != "" {
|
||||
return requested
|
||||
}
|
||||
return "gitlab-duo"
|
||||
}
|
||||
|
||||
func gitLabMetadataString(metadata map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if metadata == nil {
|
||||
return ""
|
||||
}
|
||||
if value, ok := metadata[key].(string); ok {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func gitLabOAuthTokenNeedsRefresh(metadata map[string]any) bool {
|
||||
expiry := gitLabMetadataString(metadata, "oauth_expires_at")
|
||||
if expiry == "" {
|
||||
return true
|
||||
}
|
||||
ts, err := time.Parse(time.RFC3339, expiry)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return time.Until(ts) <= 5*time.Minute
|
||||
}
|
||||
|
||||
func applyGitLabTokenMetadata(metadata map[string]any, tokenResp *gitlab.TokenResponse) {
|
||||
if metadata == nil || tokenResp == nil {
|
||||
return
|
||||
}
|
||||
if accessToken := strings.TrimSpace(tokenResp.AccessToken); accessToken != "" {
|
||||
metadata["access_token"] = accessToken
|
||||
}
|
||||
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
|
||||
metadata["refresh_token"] = refreshToken
|
||||
}
|
||||
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
|
||||
metadata["token_type"] = tokenType
|
||||
}
|
||||
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
|
||||
metadata["scope"] = scope
|
||||
}
|
||||
if expiry := gitlab.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
|
||||
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlab.DirectAccessResponse) {
|
||||
if metadata == nil || direct == nil {
|
||||
return
|
||||
}
|
||||
if base := strings.TrimSpace(direct.BaseURL); base != "" {
|
||||
metadata["duo_gateway_base_url"] = base
|
||||
}
|
||||
if token := strings.TrimSpace(direct.Token); token != "" {
|
||||
metadata["duo_gateway_token"] = token
|
||||
}
|
||||
if direct.ExpiresAt > 0 {
|
||||
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
|
||||
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
|
||||
if ttl := expiry.Sub(time.Now().UTC()); ttl > 0 {
|
||||
interval := int(ttl.Seconds()) / 2
|
||||
switch {
|
||||
case interval < 60:
|
||||
interval = 60
|
||||
case interval > 240:
|
||||
interval = 240
|
||||
}
|
||||
metadata["refresh_interval_seconds"] = interval
|
||||
}
|
||||
}
|
||||
if len(direct.Headers) > 0 {
|
||||
headers := make(map[string]string, len(direct.Headers))
|
||||
for key, value := range direct.Headers {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
headers[key] = value
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
metadata["duo_gateway_headers"] = headers
|
||||
}
|
||||
}
|
||||
if direct.ModelDetails != nil {
|
||||
modelDetails := map[string]any{}
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
modelDetails["model_provider"] = provider
|
||||
metadata["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
modelDetails["model_name"] = model
|
||||
metadata["model_name"] = model
|
||||
}
|
||||
if len(modelDetails) > 0 {
|
||||
metadata["model_details"] = modelDetails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func gitLabAuthKind(method string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(method)) {
|
||||
case gitLabAuthMethodPAT:
|
||||
return "personal_access_token"
|
||||
default:
|
||||
return "oauth"
|
||||
}
|
||||
}
|
||||
|
||||
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
models := make([]*registry.ModelInfo, 0, 4)
|
||||
seen := make(map[string]struct{}, 4)
|
||||
addModel := func(id, displayName, provider string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(id)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
models = append(models, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "gitlab",
|
||||
Type: "gitlab",
|
||||
DisplayName: displayName,
|
||||
Description: provider,
|
||||
UserDefined: true,
|
||||
})
|
||||
}
|
||||
|
||||
addModel("gitlab-duo", "GitLab Duo", "gitlab")
|
||||
if auth == nil {
|
||||
return models
|
||||
}
|
||||
for _, model := range gitlab.ExtractDiscoveredModels(auth.Metadata) {
|
||||
name := strings.TrimSpace(model.ModelName)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
displayName := "GitLab Duo"
|
||||
if provider := strings.TrimSpace(model.ModelProvider); provider != "" {
|
||||
displayName = fmt.Sprintf("GitLab Duo (%s)", provider)
|
||||
}
|
||||
addModel(name, displayName, strings.TrimSpace(model.ModelProvider))
|
||||
}
|
||||
return models
|
||||
}
|
||||
124
internal/runtime/executor/gitlab_executor_test.go
Normal file
124
internal/runtime/executor/gitlab_executor_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
func TestGitLabExecutorRefresh_WithPATStoresGatewayMetadata(t *testing.T) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
|
||||
t.Fatalf("unexpected path %s", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer pat-123" {
|
||||
t.Fatalf("unexpected Authorization header %q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"base_url":"` + server.URL + `",
|
||||
"token":"gateway-token",
|
||||
"expires_at":2000000000,
|
||||
"headers":{"X-Gitlab-Realm":"saas"},
|
||||
"model_details":{"model_provider":"mistral","model_name":"codestral-2501"}
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewGitLabExecutor(nil)
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "gitlab-pat.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": "pat",
|
||||
"base_url": server.URL,
|
||||
"personal_access_token": "pat-123",
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := exec.Refresh(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("Refresh returned error: %v", err)
|
||||
}
|
||||
if got := metadataString(updated.Metadata, "duo_gateway_token"); got != "gateway-token" {
|
||||
t.Fatalf("unexpected gateway token %q", got)
|
||||
}
|
||||
if got := gitLabModelName(updated); got != "codestral-2501" {
|
||||
t.Fatalf("unexpected model name %q", got)
|
||||
}
|
||||
headers := gitLabHeaders(updated)
|
||||
if headers["X-Gitlab-Realm"] != "saas" {
|
||||
t.Fatalf("unexpected gateway headers %+v", headers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecute_UsesGatewayHeadersAndResolvedModel(t *testing.T) {
|
||||
var receivedAuth string
|
||||
var receivedRealm string
|
||||
var receivedModel string
|
||||
|
||||
gateway := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
receivedAuth = r.Header.Get("Authorization")
|
||||
receivedRealm = r.Header.Get("X-Gitlab-Realm")
|
||||
receivedModel = findJSONField(string(body), `"model":"`, `"`)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"ok","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer gateway.Close()
|
||||
|
||||
exec := NewGitLabExecutor(nil)
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "gitlab-oauth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": "oauth",
|
||||
"duo_gateway_base_url": gateway.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]any{"X-Gitlab-Realm": "saas"},
|
||||
"model_details": map[string]any{"model_name": "codestral-2501", "model_provider": "mistral"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned error: %v", err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatal("expected non-empty payload")
|
||||
}
|
||||
if receivedAuth != "Bearer gateway-token" {
|
||||
t.Fatalf("unexpected Authorization header %q", receivedAuth)
|
||||
}
|
||||
if receivedRealm != "saas" {
|
||||
t.Fatalf("unexpected X-Gitlab-Realm header %q", receivedRealm)
|
||||
}
|
||||
if receivedModel != "codestral-2501" {
|
||||
t.Fatalf("unexpected resolved model %q", receivedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func findJSONField(body, prefix, suffix string) string {
|
||||
start := strings.Index(body, prefix)
|
||||
if start < 0 {
|
||||
return ""
|
||||
}
|
||||
start += len(prefix)
|
||||
end := strings.Index(body[start:], suffix)
|
||||
if end < 0 {
|
||||
return ""
|
||||
}
|
||||
return body[start : start+end]
|
||||
}
|
||||
462
sdk/auth/gitlab.go
Normal file
462
sdk/auth/gitlab.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
gitLabLoginModeMetadataKey = "login_mode"
|
||||
gitLabLoginModeOAuth = "oauth"
|
||||
gitLabLoginModePAT = "pat"
|
||||
gitLabBaseURLMetadataKey = "base_url"
|
||||
gitLabOAuthClientIDMetadataKey = "oauth_client_id"
|
||||
gitLabOAuthClientSecretMetadataKey = "oauth_client_secret"
|
||||
gitLabPersonalAccessTokenMetadataKey = "personal_access_token"
|
||||
)
|
||||
|
||||
var gitLabRefreshLead = 5 * time.Minute
|
||||
|
||||
type GitLabAuthenticator struct {
|
||||
CallbackPort int
|
||||
}
|
||||
|
||||
func NewGitLabAuthenticator() *GitLabAuthenticator {
|
||||
return &GitLabAuthenticator{CallbackPort: gitlabauth.DefaultCallbackPort}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) Provider() string {
|
||||
return "gitlab"
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) RefreshLead() *time.Duration {
|
||||
return &gitLabRefreshLead
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(opts.Metadata[gitLabLoginModeMetadataKey])) {
|
||||
case "", gitLabLoginModeOAuth:
|
||||
return a.loginOAuth(ctx, cfg, opts)
|
||||
case gitLabLoginModePAT:
|
||||
return a.loginPAT(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("gitlab auth: unsupported login mode %q", opts.Metadata[gitLabLoginModeMetadataKey])
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) loginOAuth(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
client := gitlabauth.NewAuthClient(cfg)
|
||||
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
|
||||
clientID, err := a.requireInput(opts, gitLabOAuthClientIDMetadataKey, "Enter GitLab OAuth application client ID: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSecret, err := a.optionalInput(opts, gitLabOAuthClientSecretMetadataKey, "Enter GitLab OAuth application client secret (press Enter for public PKCE app): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
redirectURI := gitlabauth.RedirectURL(callbackPort)
|
||||
|
||||
pkceCodes, err := gitlabauth.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := gitlabauth.NewOAuthServer(callbackPort)
|
||||
if err := oauthServer.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
|
||||
log.Warnf("gitlab oauth server stop error: %v", stopErr)
|
||||
}
|
||||
}()
|
||||
|
||||
authURL, err := client.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for GitLab Duo authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for GitLab OAuth callback...")
|
||||
|
||||
callbackCh := make(chan *gitlabauth.OAuthResult, 1)
|
||||
callbackErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
result, waitErr := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if waitErr != nil {
|
||||
callbackErrCh <- waitErr
|
||||
return
|
||||
}
|
||||
callbackCh <- result
|
||||
}()
|
||||
|
||||
var result *gitlabauth.OAuthResult
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
input, promptErr := opts.Prompt("Paste the GitLab callback URL (or press Enter to keep waiting): ")
|
||||
if promptErr != nil {
|
||||
return nil, promptErr
|
||||
}
|
||||
parsed, parseErr := misc.ParseOAuthCallback(input)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
result = &gitlabauth.OAuthResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
}
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("gitlab oauth returned error: %s", result.Error)
|
||||
}
|
||||
if result.State != state {
|
||||
return nil, fmt.Errorf("gitlab auth: state mismatch")
|
||||
}
|
||||
|
||||
tokenResp, err := client.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, result.Code, pkceCodes.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("gitlab auth: missing access token")
|
||||
}
|
||||
|
||||
user, err := client.GetCurrentUser(ctx, baseURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata[gitLabOAuthClientIDMetadataKey] = clientID
|
||||
if strings.TrimSpace(clientSecret) != "" {
|
||||
metadata[gitLabOAuthClientSecretMetadataKey] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
fmt.Println("GitLab Duo authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: identifier,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) loginPAT(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
client := gitlabauth.NewAuthClient(cfg)
|
||||
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
|
||||
token, err := a.requireInput(opts, gitLabPersonalAccessTokenMetadataKey, "Enter GitLab personal access token: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := client.GetCurrentUser(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = client.GetPersonalAccessTokenSelf(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
|
||||
metadata["auth_kind"] = "personal_access_token"
|
||||
metadata[gitLabPersonalAccessTokenMetadataKey] = strings.TrimSpace(token)
|
||||
metadata["token_preview"] = maskGitLabToken(token)
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
fmt.Println("GitLab Duo PAT authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: identifier + " (PAT)",
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
|
||||
metadata := map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": strings.TrimSpace(mode),
|
||||
gitLabBaseURLMetadataKey: gitlabauth.NormalizeBaseURL(baseURL),
|
||||
"last_refresh": time.Now().UTC().Format(time.RFC3339),
|
||||
"refresh_interval_seconds": 240,
|
||||
}
|
||||
if tokenResp != nil {
|
||||
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
|
||||
metadata["refresh_token"] = refreshToken
|
||||
}
|
||||
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
|
||||
metadata["token_type"] = tokenType
|
||||
}
|
||||
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
|
||||
metadata["scope"] = scope
|
||||
}
|
||||
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
|
||||
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
mergeGitLabDirectAccessMetadata(metadata, direct)
|
||||
return metadata
|
||||
}
|
||||
|
||||
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
|
||||
if metadata == nil || direct == nil {
|
||||
return
|
||||
}
|
||||
if base := strings.TrimSpace(direct.BaseURL); base != "" {
|
||||
metadata["duo_gateway_base_url"] = base
|
||||
}
|
||||
if token := strings.TrimSpace(direct.Token); token != "" {
|
||||
metadata["duo_gateway_token"] = token
|
||||
}
|
||||
if direct.ExpiresAt > 0 {
|
||||
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
|
||||
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
|
||||
now := time.Now().UTC()
|
||||
if ttl := expiry.Sub(now); ttl > 0 {
|
||||
interval := int(ttl.Seconds()) / 2
|
||||
switch {
|
||||
case interval < 60:
|
||||
interval = 60
|
||||
case interval > 240:
|
||||
interval = 240
|
||||
}
|
||||
metadata["refresh_interval_seconds"] = interval
|
||||
}
|
||||
}
|
||||
if len(direct.Headers) > 0 {
|
||||
headers := make(map[string]string, len(direct.Headers))
|
||||
for key, value := range direct.Headers {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
headers[key] = value
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
metadata["duo_gateway_headers"] = headers
|
||||
}
|
||||
}
|
||||
if direct.ModelDetails != nil {
|
||||
modelDetails := map[string]any{}
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
modelDetails["model_provider"] = provider
|
||||
metadata["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
modelDetails["model_name"] = model
|
||||
metadata["model_name"] = model
|
||||
}
|
||||
if len(modelDetails) > 0 {
|
||||
metadata["model_details"] = modelDetails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) resolveString(opts *LoginOptions, key, fallback string) string {
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
if value := strings.TrimSpace(opts.Metadata[key]); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(fallback) != "" {
|
||||
return fallback
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) requireInput(opts *LoginOptions, key, prompt string) (string, error) {
|
||||
if value := a.resolveString(opts, key, ""); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
value, err := opts.Prompt(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("gitlab auth: missing required %s", key)
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) optionalInput(opts *LoginOptions, key, prompt string) (string, error) {
|
||||
if value := a.resolveString(opts, key, ""); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
value, err := opts.Prompt(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func primaryGitLabEmail(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
if value := strings.TrimSpace(user.Email); value != "" {
|
||||
return value
|
||||
}
|
||||
return strings.TrimSpace(user.PublicEmail)
|
||||
}
|
||||
|
||||
func gitLabAccountIdentifier(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return "user"
|
||||
}
|
||||
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
|
||||
func sanitizeGitLabFileName(value string) string {
|
||||
value = strings.TrimSpace(strings.ToLower(value))
|
||||
if value == "" {
|
||||
return "user"
|
||||
}
|
||||
var builder strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r >= '0' && r <= '9':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
default:
|
||||
if !lastDash {
|
||||
builder.WriteRune('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
}
|
||||
result := strings.Trim(builder.String(), "-")
|
||||
if result == "" {
|
||||
return "user"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskGitLabToken(token string) string {
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if len(trimmed) <= 8 {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
|
||||
}
|
||||
@@ -17,6 +17,7 @@ func init() {
|
||||
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -390,6 +390,27 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
|
||||
// Check metadata for email first (OAuth-style auth)
|
||||
if a.Metadata != nil {
|
||||
if method, ok := a.Metadata["auth_method"].(string); ok {
|
||||
switch strings.ToLower(strings.TrimSpace(method)) {
|
||||
case "oauth":
|
||||
for _, key := range []string{"email", "username", "name"} {
|
||||
if value, okValue := a.Metadata[key].(string); okValue {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return "oauth", trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
case "pat", "personal_access_token":
|
||||
for _, key := range []string{"username", "email", "name", "token_preview"} {
|
||||
if value, okValue := a.Metadata[key].(string); okValue {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return "personal_access_token", trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return "personal_access_token", ""
|
||||
}
|
||||
}
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
email := strings.TrimSpace(v)
|
||||
if email != "" {
|
||||
|
||||
@@ -119,6 +119,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -444,6 +445,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
case "gitlab":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
@@ -891,7 +894,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kimi":
|
||||
models = registry.GetKimiModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "github-copilot":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
@@ -903,6 +906,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "kilo":
|
||||
models = executor.FetchKiloModels(context.Background(), a, s.cfg)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "gitlab":
|
||||
models = executor.GitLabModelsFromAuth(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
|
||||
59
sdk/cliproxy/service_gitlab_models_test.go
Normal file
59
sdk/cliproxy/service_gitlab_models_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestRegisterModelsForAuth_GitLabUsesDiscoveredModelAndAlias(t *testing.T) {
|
||||
service := &Service{cfg: &config.Config{}}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-auth",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "mistral",
|
||||
"model_name": "codestral-2501",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
|
||||
models := reg.GetModelsForClient(auth.ID)
|
||||
if len(models) == 0 {
|
||||
t.Fatal("expected GitLab models to be registered")
|
||||
}
|
||||
|
||||
seenActual := false
|
||||
seenAlias := false
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(model.ID) {
|
||||
case "codestral-2501":
|
||||
seenActual = true
|
||||
case "gitlab-duo":
|
||||
seenAlias = true
|
||||
}
|
||||
}
|
||||
|
||||
if !seenActual {
|
||||
t.Fatal("expected discovered GitLab model to be registered")
|
||||
}
|
||||
if !seenAlias {
|
||||
t.Fatal("expected stable GitLab Duo alias to be registered")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user