mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-01 15:52:32 +00:00
feat(cli): add proxy authentication support
Add the --auth parameter to enable proxy authentication for HTTP and HTTPS tunnels, supporting password verification and session management. - Add --auth flag in CLI to set proxy authentication password - Implement server-side authentication handling and login page - Support Cookie-based session management and validation - Add protocol message definitions related to authentication
This commit is contained in:
@@ -3,6 +3,9 @@ package proxy
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -32,6 +35,57 @@ var bufioReaderPool = sync.Pool{
|
||||
}
|
||||
|
||||
const openStreamTimeout = 3 * time.Second
|
||||
const authCookieName = "drip_auth"
|
||||
const authSessionDuration = 24 * time.Hour
|
||||
|
||||
type authSession struct {
|
||||
subdomain string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type authSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*authSession
|
||||
}
|
||||
|
||||
var sessionStore = &authSessionStore{
|
||||
sessions: make(map[string]*authSession),
|
||||
}
|
||||
|
||||
func (s *authSessionStore) create(subdomain string) string {
|
||||
token := generateSessionToken()
|
||||
s.mu.Lock()
|
||||
s.sessions[token] = &authSession{
|
||||
subdomain: subdomain,
|
||||
expiresAt: time.Now().Add(authSessionDuration),
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *authSessionStore) validate(token, subdomain string) bool {
|
||||
s.mu.RLock()
|
||||
session, ok := s.sessions[token]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if time.Now().After(session.expiresAt) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, token)
|
||||
s.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
return session.subdomain == subdomain
|
||||
}
|
||||
|
||||
func generateSessionToken() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
hash := sha256.Sum256(b)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
manager *tunnel.Manager
|
||||
@@ -45,13 +99,13 @@ var privateNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
privateCIDRs := []string{
|
||||
"127.0.0.0/8", // IPv4 loopback
|
||||
"10.0.0.0/8", // RFC 1918 Class A
|
||||
"172.16.0.0/12", // RFC 1918 Class B
|
||||
"192.168.0.0/16", // RFC 1918 Class C
|
||||
"::1/128", // IPv6 loopback
|
||||
"fc00::/7", // IPv6 unique local
|
||||
"fe80::/10", // IPv6 link-local
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
}
|
||||
for _, cidr := range privateCIDRs {
|
||||
_, ipNet, _ := net.ParseCIDR(cidr)
|
||||
@@ -107,6 +161,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check proxy authentication
|
||||
if tconn.HasProxyAuth() {
|
||||
if r.URL.Path == "/_drip/login" {
|
||||
h.handleProxyLogin(w, r, tconn, subdomain)
|
||||
return
|
||||
}
|
||||
if !h.isProxyAuthenticated(r, subdomain) {
|
||||
h.serveLoginPage(w, r, subdomain, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tType := tconn.GetTunnelType()
|
||||
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
|
||||
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
|
||||
@@ -638,3 +704,145 @@ type bufferedReadWriteCloser struct {
|
||||
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
func (h *Handler) isProxyAuthenticated(r *http.Request, subdomain string) bool {
|
||||
cookie, err := r.Cookie(authCookieName + "_" + subdomain)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return sessionStore.validate(cookie.Value, subdomain)
|
||||
}
|
||||
|
||||
func (h *Handler) handleProxyLogin(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection, subdomain string) {
|
||||
if r.Method != http.MethodPost {
|
||||
h.serveLoginPage(w, r, subdomain, "")
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
h.serveLoginPage(w, r, subdomain, "Invalid form data")
|
||||
return
|
||||
}
|
||||
|
||||
password := r.FormValue("password")
|
||||
|
||||
if !tconn.ValidateProxyAuth(password) {
|
||||
h.serveLoginPage(w, r, subdomain, "Invalid password")
|
||||
return
|
||||
}
|
||||
|
||||
token := sessionStore.create(subdomain)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: authCookieName + "_" + subdomain,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: int(authSessionDuration.Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
redirectURL := r.FormValue("redirect")
|
||||
if redirectURL == "" || redirectURL == "/_drip/login" {
|
||||
redirectURL = "/"
|
||||
}
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *Handler) serveLoginPage(w http.ResponseWriter, r *http.Request, subdomain string, errorMsg string) {
|
||||
redirectURL := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
redirectURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
if redirectURL == "/_drip/login" {
|
||||
redirectURL = "/"
|
||||
}
|
||||
|
||||
errorHTML := ""
|
||||
if errorMsg != "" {
|
||||
errorHTML = fmt.Sprintf(`<p class="error">%s</p>`, errorMsg)
|
||||
}
|
||||
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>%s - Drip</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #fff;
|
||||
color: #24292f;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container { max-width: 720px; margin: 0 auto; padding: 48px 24px; }
|
||||
header { margin-bottom: 48px; }
|
||||
h1 { font-size: 28px; font-weight: 600; margin-bottom: 8px; }
|
||||
h1 span { margin-right: 8px; }
|
||||
.desc { color: #57606a; font-size: 16px; }
|
||||
p { margin-bottom: 24px; }
|
||||
.error { color: #cf222e; margin-bottom: 16px; }
|
||||
.input-wrap {
|
||||
position: relative;
|
||||
background: #f6f8fa;
|
||||
border: 1px solid #d0d7de;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 12px;
|
||||
display: flex;
|
||||
}
|
||||
.input-wrap input {
|
||||
flex: 1;
|
||||
margin: 0;
|
||||
padding: 12px 16px;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
|
||||
font-size: 14px;
|
||||
background: transparent;
|
||||
border: none;
|
||||
outline: none;
|
||||
}
|
||||
.input-wrap button {
|
||||
background: #24292f;
|
||||
color: #fff;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
margin: 4px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.input-wrap button:hover { background: #32383f; }
|
||||
footer { margin-top: 48px; padding-top: 24px; border-top: 1px solid #d0d7de; }
|
||||
footer a { color: #57606a; text-decoration: none; font-size: 14px; }
|
||||
footer a:hover { color: #0969da; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1><span>🔒</span>%s</h1>
|
||||
<p class="desc">This tunnel is password protected</p>
|
||||
</header>
|
||||
|
||||
%s
|
||||
<form method="POST" action="/_drip/login">
|
||||
<input type="hidden" name="redirect" value="%s" />
|
||||
<div class="input-wrap">
|
||||
<input type="password" name="password" placeholder="Enter password" required autofocus />
|
||||
<button type="submit">Continue</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<footer>
|
||||
<a href="https://github.com/Gouryella/drip" target="_blank">GitHub</a>
|
||||
</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, subdomain, subdomain, errorHTML, redirectURL)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(html))
|
||||
}
|
||||
|
||||
@@ -190,6 +190,13 @@ func (c *Connection) Handle() error {
|
||||
)
|
||||
}
|
||||
|
||||
if req.ProxyAuth != nil && req.ProxyAuth.Enabled {
|
||||
c.tunnelConn.SetProxyAuth(req.ProxyAuth)
|
||||
c.logger.Info("Proxy authentication configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
}
|
||||
|
||||
c.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.String("tunnel_type", string(req.TunnelType)),
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Connection represents a tunnel connection from a client
|
||||
type Connection struct {
|
||||
Subdomain string
|
||||
Conn *websocket.Conn
|
||||
@@ -22,19 +21,19 @@ type Connection struct {
|
||||
LastActive time.Time
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
closed atomic.Bool // Use atomic for lock-free hot path checks
|
||||
closed atomic.Bool
|
||||
tunnelType protocol.TunnelType
|
||||
openStream func() (net.Conn, error)
|
||||
remoteIP string // Client IP for rate limiting tracking
|
||||
remoteIP string
|
||||
|
||||
bytesIn atomic.Int64
|
||||
bytesOut atomic.Int64
|
||||
activeConnections atomic.Int64
|
||||
|
||||
ipAccessChecker *netutil.IPAccessChecker
|
||||
proxyAuth *protocol.ProxyAuth
|
||||
}
|
||||
|
||||
// NewConnection creates a new tunnel connection
|
||||
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
|
||||
return &Connection{
|
||||
Subdomain: subdomain,
|
||||
@@ -46,9 +45,7 @@ func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends data through the WebSocket connection
|
||||
func (c *Connection) Send(data []byte) error {
|
||||
// Lock-free check using atomic - avoids RLock contention on hot path
|
||||
if c.closed.Load() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
@@ -61,25 +58,21 @@ func (c *Connection) Send(data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp
|
||||
func (c *Connection) UpdateActivity() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.LastActive = time.Now()
|
||||
}
|
||||
|
||||
// IsAlive checks if the connection is still alive based on last activity
|
||||
func (c *Connection) IsAlive(timeout time.Duration) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return time.Since(c.LastActive) < timeout
|
||||
}
|
||||
|
||||
// Close closes the connection and all associated channels
|
||||
func (c *Connection) Close() {
|
||||
// Use atomic swap to ensure only one goroutine closes
|
||||
if c.closed.Swap(true) {
|
||||
return // Already closed
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -89,46 +82,37 @@ func (c *Connection) Close() {
|
||||
close(c.SendCh)
|
||||
|
||||
if c.Conn != nil {
|
||||
// Send close message
|
||||
c.Conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
c.Conn.Close()
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
zap.String("subdomain", c.Subdomain),
|
||||
)
|
||||
c.logger.Info("Connection closed", zap.String("subdomain", c.Subdomain))
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connection is closed
|
||||
func (c *Connection) IsClosed() bool {
|
||||
return c.closed.Load() // Lock-free check
|
||||
return c.closed.Load()
|
||||
}
|
||||
|
||||
// SetTunnelType sets the tunnel type.
|
||||
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
c.tunnelType = tType
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetTunnelType returns the tunnel type.
|
||||
func (c *Connection) GetTunnelType() protocol.TunnelType {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.tunnelType
|
||||
}
|
||||
|
||||
// SetOpenStream registers a stream opener for this tunnel.
|
||||
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
|
||||
c.mu.Lock()
|
||||
c.openStream = open
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// OpenStream opens a new mux stream to the tunnel client.
|
||||
func (c *Connection) OpenStream() (net.Conn, error) {
|
||||
// Lock-free closed check
|
||||
if c.closed.Load() {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
@@ -161,13 +145,8 @@ func (c *Connection) AddBytesOut(n int64) {
|
||||
metrics.TunnelBytesSent.WithLabelValues(c.Subdomain, c.Subdomain, c.GetTunnelType().String()).Add(float64(n))
|
||||
}
|
||||
|
||||
func (c *Connection) GetBytesIn() int64 {
|
||||
return c.bytesIn.Load()
|
||||
}
|
||||
|
||||
func (c *Connection) GetBytesOut() int64 {
|
||||
return c.bytesOut.Load()
|
||||
}
|
||||
func (c *Connection) GetBytesIn() int64 { return c.bytesIn.Load() }
|
||||
func (c *Connection) GetBytesOut() int64 { return c.bytesOut.Load() }
|
||||
|
||||
func (c *Connection) IncActiveConnections() {
|
||||
c.activeConnections.Add(1)
|
||||
@@ -181,37 +160,60 @@ func (c *Connection) DecActiveConnections() {
|
||||
metrics.TunnelActiveConnections.WithLabelValues(c.Subdomain, c.Subdomain, c.GetTunnelType().String()).Dec()
|
||||
}
|
||||
|
||||
func (c *Connection) GetActiveConnections() int64 {
|
||||
return c.activeConnections.Load()
|
||||
}
|
||||
func (c *Connection) GetActiveConnections() int64 { return c.activeConnections.Load() }
|
||||
|
||||
// SetIPAccessControl sets the IP access control rules for this tunnel.
|
||||
func (c *Connection) SetIPAccessControl(allowCIDRs, denyIPs []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ipAccessChecker = netutil.NewIPAccessChecker(allowCIDRs, denyIPs)
|
||||
}
|
||||
|
||||
// IsIPAllowed checks if the given IP address is allowed to access this tunnel.
|
||||
func (c *Connection) IsIPAllowed(ip string) bool {
|
||||
c.mu.RLock()
|
||||
checker := c.ipAccessChecker
|
||||
c.mu.RUnlock()
|
||||
|
||||
if checker == nil {
|
||||
return true // No access control configured
|
||||
return true
|
||||
}
|
||||
return checker.IsAllowed(ip)
|
||||
}
|
||||
|
||||
// HasIPAccessControl returns true if IP access control is configured.
|
||||
func (c *Connection) HasIPAccessControl() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ipAccessChecker != nil && c.ipAccessChecker.HasRules()
|
||||
}
|
||||
|
||||
// StartWritePump starts the write pump for sending messages
|
||||
func (c *Connection) SetProxyAuth(auth *protocol.ProxyAuth) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.proxyAuth = auth
|
||||
}
|
||||
|
||||
func (c *Connection) GetProxyAuth() *protocol.ProxyAuth {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.proxyAuth
|
||||
}
|
||||
|
||||
func (c *Connection) HasProxyAuth() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.proxyAuth != nil && c.proxyAuth.Enabled
|
||||
}
|
||||
|
||||
func (c *Connection) ValidateProxyAuth(password string) bool {
|
||||
c.mu.RLock()
|
||||
auth := c.proxyAuth
|
||||
c.mu.RUnlock()
|
||||
|
||||
if auth == nil || !auth.Enabled {
|
||||
return true
|
||||
}
|
||||
return auth.Password == password
|
||||
}
|
||||
|
||||
func (c *Connection) StartWritePump() {
|
||||
if c.Conn == nil {
|
||||
go func() {
|
||||
@@ -241,15 +243,11 @@ func (c *Connection) StartWritePump() {
|
||||
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
c.logger.Error("Write error",
|
||||
zap.String("subdomain", c.Subdomain),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.logger.Error("Write error", zap.String("subdomain", c.Subdomain), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
// Send ping to keep connection alive
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user