// Mirror of https://github.com/router-for-me/CLIProxyAPIPlus.git
// (synced 2026-03-22 09:10:30 +00:00) — 493 lines, 14 KiB, Go.
package gitlab
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
	// DefaultBaseURL is the GitLab instance used when no base URL is given;
	// NormalizeBaseURL returns it for blank input.
	DefaultBaseURL = "https://gitlab.com"

	// DefaultCallbackPort is the default local port for the OAuth callback
	// listener. It is not referenced in this file; presumably callers pass it
	// to NewOAuthServer and RedirectURL — confirm at the call sites.
	DefaultCallbackPort = 17171

	// defaultOAuthScope is the space-separated scope list requested by
	// GenerateAuthURL.
	defaultOAuthScope = "api read_user"
)

type (
	// PKCECodes holds a PKCE code verifier and the S256 code challenge derived
	// from it, as produced by GeneratePKCECodes.
	PKCECodes struct {
		CodeVerifier, CodeChallenge string
	}

	// OAuthResult is the outcome reported by the local callback handler:
	// either Code and State from a successful redirect, or a non-empty Error.
	OAuthResult struct {
		Code, State, Error string
	}

	// OAuthServer is the local HTTP server that receives the OAuth redirect on
	// /auth/callback and hands the outcome to WaitForCallback. Construct it
	// with NewOAuthServer, which allocates the buffered result/error channels.
	// server is set by Start and cleared by Stop.
	OAuthServer struct {
		server     *http.Server
		port       int
		resultChan chan *OAuthResult
		errorChan  chan error

		// mu serialises Start and Stop, which read and write server and running.
		mu      sync.Mutex
		running bool
	}

	// TokenResponse is the JSON body returned by GitLab's /oauth/token endpoint.
	// CreatedAt is a Unix timestamp in seconds and ExpiresIn a lifetime in
	// seconds (both are interpreted that way by TokenExpiry).
	TokenResponse struct {
		AccessToken  string `json:"access_token"`
		TokenType    string `json:"token_type"`
		RefreshToken string `json:"refresh_token"`
		Scope        string `json:"scope"`
		CreatedAt    int64  `json:"created_at"`
		ExpiresIn    int    `json:"expires_in"`
	}

	// User is the subset of GET /api/v4/user that this package decodes.
	User struct {
		ID          int64  `json:"id"`
		Username    string `json:"username"`
		Name        string `json:"name"`
		Email       string `json:"email"`
		PublicEmail string `json:"public_email"`
	}

	// PersonalAccessTokenSelf is the subset of
	// GET /api/v4/personal_access_tokens/self that this package decodes.
	PersonalAccessTokenSelf struct {
		ID     int64    `json:"id"`
		Name   string   `json:"name"`
		Scopes []string `json:"scopes"`
		UserID int64    `json:"user_id"`
	}

	// ModelDetails names the model reported in a direct-access response.
	ModelDetails struct {
		ModelProvider string `json:"model_provider"`
		ModelName     string `json:"model_name"`
	}

	// DirectAccessResponse is the JSON body of
	// POST /api/v4/code_suggestions/direct_access. ExpiresAt is presumably a
	// Unix timestamp in seconds — NOTE(review): confirm against the GitLab API.
	// FetchDirectAccess guarantees Headers is non-nil on success.
	DirectAccessResponse struct {
		BaseURL      string            `json:"base_url"`
		Token        string            `json:"token"`
		ExpiresAt    int64             `json:"expires_at"`
		Headers      map[string]string `json:"headers"`
		ModelDetails *ModelDetails     `json:"model_details,omitempty"`
	}

	// DiscoveredModel is one model entry returned by ExtractDiscoveredModels.
	DiscoveredModel struct {
		ModelProvider, ModelName string
	}

	// AuthClient performs GitLab OAuth token exchanges and API calls through
	// the HTTP client configured by NewAuthClient.
	AuthClient struct {
		httpClient *http.Client
	}
)
|
|
|
|
func NewAuthClient(cfg *config.Config) *AuthClient {
|
|
client := &http.Client{}
|
|
if cfg != nil {
|
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
|
}
|
|
return &AuthClient{httpClient: client}
|
|
}
|
|
|
|
func NormalizeBaseURL(raw string) string {
|
|
value := strings.TrimSpace(raw)
|
|
if value == "" {
|
|
return DefaultBaseURL
|
|
}
|
|
if !strings.Contains(value, "://") {
|
|
value = "https://" + value
|
|
}
|
|
value = strings.TrimRight(value, "/")
|
|
return value
|
|
}
|
|
|
|
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
|
if token == nil {
|
|
return time.Time{}
|
|
}
|
|
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
|
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
|
}
|
|
if token.ExpiresIn > 0 {
|
|
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
|
}
|
|
return time.Time{}
|
|
}
|
|
|
|
func GeneratePKCECodes() (*PKCECodes, error) {
|
|
verifierBytes := make([]byte, 32)
|
|
if _, err := rand.Read(verifierBytes); err != nil {
|
|
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
|
}
|
|
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
|
sum := sha256.Sum256([]byte(verifier))
|
|
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
|
return &PKCECodes{
|
|
CodeVerifier: verifier,
|
|
CodeChallenge: challenge,
|
|
}, nil
|
|
}
|
|
|
|
func NewOAuthServer(port int) *OAuthServer {
|
|
return &OAuthServer{
|
|
port: port,
|
|
resultChan: make(chan *OAuthResult, 1),
|
|
errorChan: make(chan error, 1),
|
|
}
|
|
}
|
|
|
|
func (s *OAuthServer) Start() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.running {
|
|
return fmt.Errorf("gitlab oauth server already running")
|
|
}
|
|
if !s.isPortAvailable() {
|
|
return fmt.Errorf("port %d is already in use", s.port)
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/auth/callback", s.handleCallback)
|
|
|
|
s.server = &http.Server{
|
|
Addr: fmt.Sprintf(":%d", s.port),
|
|
Handler: mux,
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 10 * time.Second,
|
|
}
|
|
s.running = true
|
|
|
|
go func() {
|
|
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
s.errorChan <- err
|
|
}
|
|
}()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
return nil
|
|
}
|
|
|
|
func (s *OAuthServer) Stop(ctx context.Context) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if !s.running || s.server == nil {
|
|
return nil
|
|
}
|
|
defer func() {
|
|
s.running = false
|
|
s.server = nil
|
|
}()
|
|
return s.server.Shutdown(ctx)
|
|
}
|
|
|
|
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
|
select {
|
|
case result := <-s.resultChan:
|
|
return result, nil
|
|
case err := <-s.errorChan:
|
|
return nil, err
|
|
case <-time.After(timeout):
|
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
|
}
|
|
}
|
|
|
|
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
query := r.URL.Query()
|
|
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
|
s.sendResult(&OAuthResult{Error: errParam})
|
|
http.Error(w, errParam, http.StatusBadRequest)
|
|
return
|
|
}
|
|
code := strings.TrimSpace(query.Get("code"))
|
|
state := strings.TrimSpace(query.Get("state"))
|
|
if code == "" || state == "" {
|
|
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
|
http.Error(w, "missing code or state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
s.sendResult(&OAuthResult{Code: code, State: state})
|
|
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
|
}
|
|
|
|
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
|
select {
|
|
case s.resultChan <- result:
|
|
default:
|
|
log.Debug("gitlab oauth result channel full, dropping callback result")
|
|
}
|
|
}
|
|
|
|
func (s *OAuthServer) isPortAvailable() bool {
|
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
_ = listener.Close()
|
|
return true
|
|
}
|
|
|
|
// RedirectURL returns the OAuth redirect URI for a callback server on the
// given local port: http://localhost:<port>/auth/callback.
func RedirectURL(port int) string {
	const (
		host         = "localhost"
		callbackPath = "/auth/callback"
	)
	return fmt.Sprintf("http://%s:%d%s", host, port, callbackPath)
}
|
|
|
|
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
|
if pkce == nil {
|
|
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
|
}
|
|
if strings.TrimSpace(clientID) == "" {
|
|
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
|
}
|
|
baseURL = NormalizeBaseURL(baseURL)
|
|
params := url.Values{
|
|
"client_id": {strings.TrimSpace(clientID)},
|
|
"response_type": {"code"},
|
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
|
"scope": {defaultOAuthScope},
|
|
"state": {strings.TrimSpace(state)},
|
|
"code_challenge": {pkce.CodeChallenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
|
}
|
|
|
|
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"client_id": {strings.TrimSpace(clientID)},
|
|
"code": {strings.TrimSpace(code)},
|
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
|
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
|
}
|
|
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
|
form.Set("client_secret", secret)
|
|
}
|
|
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
|
}
|
|
|
|
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {strings.TrimSpace(refreshToken)},
|
|
}
|
|
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
|
form.Set("client_id", clientID)
|
|
}
|
|
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
|
form.Set("client_secret", secret)
|
|
}
|
|
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
|
}
|
|
|
|
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
var token TokenResponse
|
|
if err := json.Unmarshal(body, &token); err != nil {
|
|
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var user User
|
|
if err := json.Unmarshal(body, &user); err != nil {
|
|
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var pat PersonalAccessTokenSelf
|
|
if err := json.Unmarshal(body, &pat); err != nil {
|
|
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
|
}
|
|
return &pat, nil
|
|
}
|
|
|
|
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var direct DirectAccessResponse
|
|
if err := json.Unmarshal(body, &direct); err != nil {
|
|
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
|
}
|
|
if direct.Headers == nil {
|
|
direct.Headers = make(map[string]string)
|
|
}
|
|
return &direct, nil
|
|
}
|
|
|
|
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
|
if len(metadata) == 0 {
|
|
return nil
|
|
}
|
|
|
|
models := make([]DiscoveredModel, 0, 4)
|
|
seen := make(map[string]struct{})
|
|
appendModel := func(provider, name string) {
|
|
provider = strings.TrimSpace(provider)
|
|
name = strings.TrimSpace(name)
|
|
if name == "" {
|
|
return
|
|
}
|
|
key := strings.ToLower(name)
|
|
if _, ok := seen[key]; ok {
|
|
return
|
|
}
|
|
seen[key] = struct{}{}
|
|
models = append(models, DiscoveredModel{
|
|
ModelProvider: provider,
|
|
ModelName: name,
|
|
})
|
|
}
|
|
|
|
if raw, ok := metadata["model_details"]; ok {
|
|
appendDiscoveredModels(raw, appendModel)
|
|
}
|
|
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
|
|
|
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
|
if raw, ok := metadata[key]; ok {
|
|
appendDiscoveredModels(raw, appendModel)
|
|
}
|
|
}
|
|
|
|
return models
|
|
}
|
|
|
|
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
|
switch typed := raw.(type) {
|
|
case map[string]any:
|
|
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
|
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
|
if nested, ok := typed["models"]; ok {
|
|
appendDiscoveredModels(nested, appendModel)
|
|
}
|
|
case []any:
|
|
for _, item := range typed {
|
|
appendDiscoveredModels(item, appendModel)
|
|
}
|
|
case []string:
|
|
for _, item := range typed {
|
|
appendModel("", item)
|
|
}
|
|
case string:
|
|
appendModel("", typed)
|
|
}
|
|
}
|
|
|
|
// stringValue converts a loosely-typed metadata value into a trimmed string.
//
// Conversion rules:
//   - strings and fmt.Stringer values are whitespace-trimmed; this includes
//     json.Number, which implements String();
//   - int and int64 are formatted in base 10;
//   - float64, which encoding/json produces for numbers decoded into any, is
//     formatted without an exponent and without truncation, so 42 becomes
//     "42" and 1.5 becomes "1.5";
//   - any other type, including nil, yields "".
func stringValue(raw any) string {
	switch typed := raw.(type) {
	case string:
		return strings.TrimSpace(typed)
	case fmt.Stringer:
		// Also covers json.Number. A separate json.Number case placed after
		// this one was unreachable (staticcheck SA4020).
		return strings.TrimSpace(typed.String())
	case int:
		return strconv.Itoa(typed)
	case int64:
		return strconv.FormatInt(typed, 10)
	case float64:
		if typed == 0 {
			// Normalise negative zero to "0", matching the integer form.
			return "0"
		}
		// FormatFloat keeps integral values in plain integer form. It also
		// avoids the truncation of fractional values and the int64 overflow
		// (|x| >= 2^63) that an int64 conversion would introduce.
		return strconv.FormatFloat(typed, 'f', -1, 64)
	default:
		return ""
	}
}
|