mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-01 15:52:32 +00:00
feat(client): Support predefined tunnel configuration and management commands
Added predefined tunnel functionality, allowing users to define multiple tunnels in the configuration file and start them by name, including the following improvements: - Added --all flag to start all configured tunnels - Added parameterless start command to list available tunnels - Support configuration of multiple tunnel types (http, https, tcp) - Support advanced configurations such as subdomains, transport protocols, and IP allowlists refactor(deployments): Refactor Docker deployment configuration Removed old Dockerfile and Compose configurations, added new deployment files: - Removed .env.example and old Docker build files - Added Caddy reverse proxy configuration file - Added two deployment modes: standard and Caddy reverse proxy - Added detailed server configuration example files docs: Update documentation to include tunnel configuration and deployment guide Updated Chinese and English README documents: - Added usage instructions and configuration examples for predefined tunnels - Expanded server deployment section to include direct TLS and reverse proxy modes - Added server configuration reference table with detailed configuration item descriptions - Added specific configuration methods for Caddy and Nginx reverse proxies
This commit is contained in:
@@ -15,7 +15,7 @@ import (
|
||||
var configCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "Manage configuration",
|
||||
Long: "Manage Drip client configuration (server, token, etc.)",
|
||||
Long: "Manage Drip client configuration (server, token, tunnels)",
|
||||
}
|
||||
|
||||
var configInitCmd = &cobra.Command{
|
||||
@@ -135,6 +135,32 @@ func runConfigShow(_ *cobra.Command, _ []string) error {
|
||||
|
||||
fmt.Println(ui.RenderConfigShow(cfg.Server, displayToken, !configFull, cfg.TLS, config.DefaultClientConfigPath()))
|
||||
|
||||
// Show tunnels if configured
|
||||
if len(cfg.Tunnels) > 0 {
|
||||
fmt.Println()
|
||||
fmt.Println(ui.Title("Configured Tunnels"))
|
||||
for _, t := range cfg.Tunnels {
|
||||
addr := t.Address
|
||||
if addr == "" {
|
||||
addr = "127.0.0.1"
|
||||
}
|
||||
fmt.Printf(" %-12s %-6s %s:%d", t.Name, t.Type, addr, t.Port)
|
||||
if t.Subdomain != "" {
|
||||
fmt.Printf(" subdomain=%s", t.Subdomain)
|
||||
}
|
||||
if t.Transport != "" {
|
||||
fmt.Printf(" transport=%s", t.Transport)
|
||||
}
|
||||
if len(t.AllowIPs) > 0 {
|
||||
fmt.Printf(" allow=%s", strings.Join(t.AllowIPs, ","))
|
||||
}
|
||||
if len(t.DenyIPs) > 0 {
|
||||
fmt.Printf(" deny=%s", strings.Join(t.DenyIPs, ","))
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -221,6 +247,24 @@ func runConfigValidate(_ *cobra.Command, _ []string) error {
|
||||
|
||||
fmt.Println(ui.RenderConfigValidation(serverValid, serverMsg, tokenSet, tokenMsg, tlsEnabled))
|
||||
|
||||
// Validate tunnels
|
||||
if len(cfg.Tunnels) > 0 {
|
||||
fmt.Println()
|
||||
fmt.Println(ui.Title("Tunnel Validation"))
|
||||
allValid := true
|
||||
for _, t := range cfg.Tunnels {
|
||||
if err := t.Validate(); err != nil {
|
||||
fmt.Printf(" ✗ %s: %v\n", t.Name, err)
|
||||
allValid = false
|
||||
} else {
|
||||
fmt.Printf(" ✓ %s: valid\n", t.Name)
|
||||
}
|
||||
}
|
||||
if !allValid {
|
||||
return fmt.Errorf("some tunnels have invalid configuration")
|
||||
}
|
||||
}
|
||||
|
||||
if !serverValid {
|
||||
return fmt.Errorf("invalid configuration: %s", serverMsg)
|
||||
}
|
||||
|
||||
@@ -98,7 +98,6 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
}
|
||||
|
||||
// parseTransport converts a string to TransportType
|
||||
func parseTransport(s string) tcp.TransportType {
|
||||
switch strings.ToLower(s) {
|
||||
case "wss":
|
||||
|
||||
@@ -36,6 +36,7 @@ var (
|
||||
serverPprofPort int
|
||||
serverTransports string
|
||||
serverTunnelTypes string
|
||||
serverConfigFile string
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
@@ -48,6 +49,9 @@ var serverCmd = &cobra.Command{
|
||||
func init() {
|
||||
rootCmd.AddCommand(serverCmd)
|
||||
|
||||
// Config file flag
|
||||
serverCmd.Flags().StringVarP(&serverConfigFile, "config", "c", "", "Path to config file (default: /etc/drip/config.yaml or ~/.drip/server.yaml)")
|
||||
|
||||
// Command line flags with environment variable defaults
|
||||
serverCmd.Flags().IntVarP(&serverPort, "port", "p", getEnvInt("DRIP_PORT", 8443), "Server port (env: DRIP_PORT)")
|
||||
serverCmd.Flags().IntVar(&serverPublicPort, "public-port", getEnvInt("DRIP_PUBLIC_PORT", 0), "Public port to display in URLs (env: DRIP_PUBLIC_PORT)")
|
||||
@@ -71,32 +75,180 @@ func init() {
|
||||
serverCmd.Flags().StringVar(&serverTunnelTypes, "tunnel-types", getEnvString("DRIP_TUNNEL_TYPES", "http,https,tcp"), "Allowed tunnel types: http,https,tcp (env: DRIP_TUNNEL_TYPES)")
|
||||
}
|
||||
|
||||
func runServer(_ *cobra.Command, _ []string) error {
|
||||
func runServer(cmd *cobra.Command, _ []string) error {
|
||||
// Apply server-mode GC tuning (high throughput, more memory)
|
||||
tuning.ApplyMode(tuning.ModeServer)
|
||||
|
||||
if serverTLSCert == "" {
|
||||
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
|
||||
// Load config file if specified or if default exists
|
||||
var cfg *config.ServerConfig
|
||||
configPath := serverConfigFile
|
||||
if configPath == "" && config.ServerConfigExists("") {
|
||||
configPath = config.DefaultServerConfigPath()
|
||||
}
|
||||
if serverTLSKey == "" {
|
||||
return fmt.Errorf("TLS private key path is required (use --tls-key flag or DRIP_TLS_KEY environment variable)")
|
||||
if configPath != "" {
|
||||
var err error
|
||||
cfg, err = config.LoadServerConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &config.ServerConfig{}
|
||||
}
|
||||
|
||||
if err := utils.InitServerLogger(serverDebug); err != nil {
|
||||
// Configuration priority: flag > env > config file > default
|
||||
// Note: flag variables already contain env defaults from init()
|
||||
// We need to check if flag was explicitly set, or if env var exists
|
||||
|
||||
// Port: flag > env > config > default(8443)
|
||||
if cmd.Flags().Changed("port") {
|
||||
cfg.Port = serverPort
|
||||
} else if os.Getenv("DRIP_PORT") != "" {
|
||||
cfg.Port = serverPort // serverPort already has env value
|
||||
} else if cfg.Port == 0 {
|
||||
cfg.Port = serverPort // use default
|
||||
}
|
||||
|
||||
// PublicPort: flag > env > config > default(0)
|
||||
// Note: 0 is a valid value meaning "same as port"
|
||||
if cmd.Flags().Changed("public-port") {
|
||||
cfg.PublicPort = serverPublicPort
|
||||
} else if os.Getenv("DRIP_PUBLIC_PORT") != "" {
|
||||
cfg.PublicPort = serverPublicPort
|
||||
}
|
||||
// else keep config file value (including 0)
|
||||
|
||||
// Domain: flag > env > config > default
|
||||
if cmd.Flags().Changed("domain") {
|
||||
cfg.Domain = serverDomain
|
||||
} else if os.Getenv("DRIP_DOMAIN") != "" {
|
||||
cfg.Domain = serverDomain
|
||||
} else if cfg.Domain == "" {
|
||||
cfg.Domain = serverDomain
|
||||
}
|
||||
|
||||
// TunnelDomain: flag > env > config > default("")
|
||||
if cmd.Flags().Changed("tunnel-domain") {
|
||||
cfg.TunnelDomain = serverTunnelDomain
|
||||
} else if os.Getenv("DRIP_TUNNEL_DOMAIN") != "" {
|
||||
cfg.TunnelDomain = serverTunnelDomain
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// AuthToken: flag > env > config > default("")
|
||||
if cmd.Flags().Changed("token") {
|
||||
cfg.AuthToken = serverAuthToken
|
||||
} else if os.Getenv("DRIP_TOKEN") != "" {
|
||||
cfg.AuthToken = serverAuthToken
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// MetricsToken: flag > env > config > default("")
|
||||
if cmd.Flags().Changed("metrics-token") {
|
||||
cfg.MetricsToken = serverMetricsToken
|
||||
} else if os.Getenv("DRIP_METRICS_TOKEN") != "" {
|
||||
cfg.MetricsToken = serverMetricsToken
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// Debug: flag > config > default(false)
|
||||
// Note: debug has no env var
|
||||
if cmd.Flags().Changed("debug") {
|
||||
cfg.Debug = serverDebug
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// TCPPortMin: flag > env > config > default
|
||||
if cmd.Flags().Changed("tcp-port-min") {
|
||||
cfg.TCPPortMin = serverTCPPortMin
|
||||
} else if os.Getenv("DRIP_TCP_PORT_MIN") != "" {
|
||||
cfg.TCPPortMin = serverTCPPortMin
|
||||
} else if cfg.TCPPortMin == 0 {
|
||||
cfg.TCPPortMin = serverTCPPortMin
|
||||
}
|
||||
|
||||
// TCPPortMax: flag > env > config > default
|
||||
if cmd.Flags().Changed("tcp-port-max") {
|
||||
cfg.TCPPortMax = serverTCPPortMax
|
||||
} else if os.Getenv("DRIP_TCP_PORT_MAX") != "" {
|
||||
cfg.TCPPortMax = serverTCPPortMax
|
||||
} else if cfg.TCPPortMax == 0 {
|
||||
cfg.TCPPortMax = serverTCPPortMax
|
||||
}
|
||||
|
||||
// TLSCertFile: flag > env > config > default("")
|
||||
if cmd.Flags().Changed("tls-cert") {
|
||||
cfg.TLSCertFile = serverTLSCert
|
||||
} else if os.Getenv("DRIP_TLS_CERT") != "" {
|
||||
cfg.TLSCertFile = serverTLSCert
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// TLSKeyFile: flag > env > config > default("")
|
||||
if cmd.Flags().Changed("tls-key") {
|
||||
cfg.TLSKeyFile = serverTLSKey
|
||||
} else if os.Getenv("DRIP_TLS_KEY") != "" {
|
||||
cfg.TLSKeyFile = serverTLSKey
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// PprofPort: flag > env > config > default(0)
|
||||
// Note: 0 is valid meaning "disabled"
|
||||
if cmd.Flags().Changed("pprof") {
|
||||
cfg.PprofPort = serverPprofPort
|
||||
} else if os.Getenv("DRIP_PPROF_PORT") != "" {
|
||||
cfg.PprofPort = serverPprofPort
|
||||
}
|
||||
// else keep config file value
|
||||
|
||||
// AllowedTransports: flag > env > config > default
|
||||
if cmd.Flags().Changed("transports") {
|
||||
cfg.AllowedTransports = parseCommaSeparated(serverTransports)
|
||||
} else if os.Getenv("DRIP_TRANSPORTS") != "" {
|
||||
cfg.AllowedTransports = parseCommaSeparated(serverTransports)
|
||||
} else if len(cfg.AllowedTransports) == 0 {
|
||||
cfg.AllowedTransports = parseCommaSeparated(serverTransports)
|
||||
}
|
||||
|
||||
// AllowedTunnelTypes: flag > env > config > default
|
||||
if cmd.Flags().Changed("tunnel-types") {
|
||||
cfg.AllowedTunnelTypes = parseCommaSeparated(serverTunnelTypes)
|
||||
} else if os.Getenv("DRIP_TUNNEL_TYPES") != "" {
|
||||
cfg.AllowedTunnelTypes = parseCommaSeparated(serverTunnelTypes)
|
||||
} else if len(cfg.AllowedTunnelTypes) == 0 {
|
||||
cfg.AllowedTunnelTypes = parseCommaSeparated(serverTunnelTypes)
|
||||
}
|
||||
|
||||
// TLS is always enabled for server
|
||||
cfg.TLSEnabled = true
|
||||
|
||||
// Validate required fields
|
||||
if cfg.TLSCertFile == "" {
|
||||
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag, DRIP_TLS_CERT environment variable, or config file)")
|
||||
}
|
||||
if cfg.TLSKeyFile == "" {
|
||||
return fmt.Errorf("TLS private key path is required (use --tls-key flag, DRIP_TLS_KEY environment variable, or config file)")
|
||||
}
|
||||
|
||||
if err := utils.InitServerLogger(cfg.Debug); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
if configPath != "" {
|
||||
logger.Info("Loaded configuration from file", zap.String("path", configPath))
|
||||
}
|
||||
|
||||
logger.Info("Starting Drip Server",
|
||||
zap.String("version", Version),
|
||||
zap.String("commit", GitCommit),
|
||||
)
|
||||
|
||||
if serverPprofPort > 0 {
|
||||
if cfg.PprofPort > 0 {
|
||||
go func() {
|
||||
pprofAddr := fmt.Sprintf("localhost:%d", serverPprofPort)
|
||||
pprofAddr := fmt.Sprintf("localhost:%d", cfg.PprofPort)
|
||||
logger.Info("Starting pprof server", zap.String("address", pprofAddr))
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
logger.Error("pprof server failed", zap.Error(err))
|
||||
@@ -104,63 +256,46 @@ func runServer(_ *cobra.Command, _ []string) error {
|
||||
}()
|
||||
}
|
||||
|
||||
displayPort := serverPublicPort
|
||||
if displayPort == 0 {
|
||||
displayPort = serverPort
|
||||
// Set public port for display if not specified
|
||||
if cfg.PublicPort == 0 {
|
||||
cfg.PublicPort = cfg.Port
|
||||
}
|
||||
|
||||
// Use tunnel domain if set, otherwise fall back to domain
|
||||
tunnelDomain := serverTunnelDomain
|
||||
if tunnelDomain == "" {
|
||||
tunnelDomain = serverDomain
|
||||
// Use tunnel domain if not set, fall back to domain
|
||||
if cfg.TunnelDomain == "" {
|
||||
cfg.TunnelDomain = cfg.Domain
|
||||
}
|
||||
|
||||
serverConfig := &config.ServerConfig{
|
||||
Port: serverPort,
|
||||
PublicPort: displayPort,
|
||||
Domain: serverDomain,
|
||||
TunnelDomain: tunnelDomain,
|
||||
TCPPortMin: serverTCPPortMin,
|
||||
TCPPortMax: serverTCPPortMax,
|
||||
TLSEnabled: true,
|
||||
TLSCertFile: serverTLSCert,
|
||||
TLSKeyFile: serverTLSKey,
|
||||
AuthToken: serverAuthToken,
|
||||
Debug: serverDebug,
|
||||
AllowedTransports: parseCommaSeparated(serverTransports),
|
||||
AllowedTunnelTypes: parseCommaSeparated(serverTunnelTypes),
|
||||
}
|
||||
|
||||
if err := serverConfig.Validate(); err != nil {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
logger.Fatal("Invalid server configuration", zap.Error(err))
|
||||
}
|
||||
|
||||
tlsConfig, err := serverConfig.LoadTLSConfig()
|
||||
tlsConfig, err := cfg.LoadTLSConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to load TLS configuration", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.Info("TLS 1.3 configuration loaded",
|
||||
zap.String("cert", serverTLSCert),
|
||||
zap.String("key", serverTLSKey),
|
||||
zap.String("cert", cfg.TLSCertFile),
|
||||
zap.String("key", cfg.TLSKeyFile),
|
||||
)
|
||||
|
||||
tunnelManager := tunnel.NewManager(logger)
|
||||
|
||||
portAllocator, err := tcp.NewPortAllocator(serverTCPPortMin, serverTCPPortMax)
|
||||
portAllocator, err := tcp.NewPortAllocator(cfg.TCPPortMin, cfg.TCPPortMax)
|
||||
if err != nil {
|
||||
logger.Fatal("Invalid TCP port range", zap.Error(err))
|
||||
}
|
||||
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", cfg.Port)
|
||||
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, tunnelDomain, serverAuthToken, serverMetricsToken)
|
||||
httpHandler.SetAllowedTransports(serverConfig.AllowedTransports)
|
||||
httpHandler.SetAllowedTunnelTypes(serverConfig.AllowedTunnelTypes)
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, cfg.TunnelDomain, cfg.AuthToken, cfg.MetricsToken)
|
||||
httpHandler.SetAllowedTransports(cfg.AllowedTransports)
|
||||
httpHandler.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
|
||||
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, tunnelDomain, displayPort, httpHandler)
|
||||
listener.SetAllowedTransports(serverConfig.AllowedTransports)
|
||||
listener.SetAllowedTunnelTypes(serverConfig.AllowedTunnelTypes)
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, cfg.AuthToken, tunnelManager, logger, portAllocator, cfg.Domain, cfg.TunnelDomain, cfg.PublicPort, httpHandler)
|
||||
listener.SetAllowedTransports(cfg.AllowedTransports)
|
||||
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
|
||||
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
@@ -168,11 +303,11 @@ func runServer(_ *cobra.Command, _ []string) error {
|
||||
|
||||
logger.Info("Drip Server started",
|
||||
zap.String("address", listenAddr),
|
||||
zap.String("domain", serverDomain),
|
||||
zap.String("tunnel_domain", tunnelDomain),
|
||||
zap.String("domain", cfg.Domain),
|
||||
zap.String("tunnel_domain", cfg.TunnelDomain),
|
||||
zap.String("protocol", "TCP over TLS 1.3"),
|
||||
zap.Strings("transports", serverConfig.AllowedTransports),
|
||||
zap.Strings("tunnel_types", serverConfig.AllowedTunnelTypes),
|
||||
zap.Strings("transports", cfg.AllowedTransports),
|
||||
zap.Strings("tunnel_types", cfg.AllowedTunnelTypes),
|
||||
)
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
|
||||
@@ -114,10 +114,10 @@ func runServerConfigShow(_ *cobra.Command, _ []string) error {
|
||||
}
|
||||
|
||||
// Configuration sources
|
||||
fmt.Println("📋 Configuration Sources:")
|
||||
fmt.Println("Configuration Sources:")
|
||||
fmt.Println(" Command-line flags (highest priority)")
|
||||
fmt.Println(" Environment variables (DRIP_*)")
|
||||
fmt.Println(" Command-line flags")
|
||||
fmt.Println(" Config file: /etc/drip/server.env")
|
||||
fmt.Println(" Config file: /etc/drip/config.yaml or ~/.drip/server.yaml")
|
||||
fmt.Println()
|
||||
|
||||
// Endpoints
|
||||
|
||||
250
internal/client/cli/start.go
Normal file
250
internal/client/cli/start.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/ui"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
startAll bool
|
||||
)
|
||||
|
||||
var startCmd = &cobra.Command{
|
||||
Use: "start [tunnel-names...]",
|
||||
Short: "Start predefined tunnels from config",
|
||||
Long: `Start one or more predefined tunnels from your configuration file.
|
||||
|
||||
Examples:
|
||||
drip start web Start the tunnel named "web"
|
||||
drip start web api Start multiple tunnels
|
||||
drip start --all Start all configured tunnels
|
||||
|
||||
Configuration file example (~/.drip/config.yaml):
|
||||
server: tunnel.example.com:443
|
||||
token: your-token
|
||||
tls: true
|
||||
tunnels:
|
||||
- name: web
|
||||
type: http
|
||||
port: 3000
|
||||
subdomain: myapp
|
||||
|
||||
- name: api
|
||||
type: http
|
||||
port: 8080
|
||||
subdomain: api
|
||||
transport: wss
|
||||
|
||||
- name: db
|
||||
type: tcp
|
||||
port: 5432
|
||||
subdomain: postgres
|
||||
allow_ips:
|
||||
- 192.168.0.0/16
|
||||
- 10.0.0.0/8`,
|
||||
RunE: runStart,
|
||||
}
|
||||
|
||||
func init() {
|
||||
startCmd.Flags().BoolVar(&startAll, "all", false, "Start all configured tunnels")
|
||||
rootCmd.AddCommand(startCmd)
|
||||
}
|
||||
|
||||
func runStart(_ *cobra.Command, args []string) error {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(cfg.Tunnels) == 0 {
|
||||
return fmt.Errorf("no tunnels configured in %s", config.DefaultClientConfigPath())
|
||||
}
|
||||
|
||||
var tunnelsToStart []*config.TunnelConfig
|
||||
|
||||
if startAll {
|
||||
tunnelsToStart = cfg.Tunnels
|
||||
} else if len(args) == 0 {
|
||||
// No args and no --all flag, show available tunnels
|
||||
fmt.Println(ui.Title("Available Tunnels"))
|
||||
fmt.Println()
|
||||
for _, t := range cfg.Tunnels {
|
||||
fmt.Printf(" %s\n", formatTunnelInfo(t))
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println("Usage:")
|
||||
fmt.Println(" drip start <tunnel-name> Start a specific tunnel")
|
||||
fmt.Println(" drip start --all Start all tunnels")
|
||||
return nil
|
||||
} else {
|
||||
// Start specific tunnels by name
|
||||
for _, name := range args {
|
||||
t := cfg.GetTunnel(name)
|
||||
if t == nil {
|
||||
availableNames := cfg.GetTunnelNames()
|
||||
return fmt.Errorf("tunnel '%s' not found. Available tunnels: %s", name, strings.Join(availableNames, ", "))
|
||||
}
|
||||
tunnelsToStart = append(tunnelsToStart, t)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tunnelsToStart) == 0 {
|
||||
return fmt.Errorf("no tunnels to start")
|
||||
}
|
||||
|
||||
// Start tunnels
|
||||
if len(tunnelsToStart) == 1 {
|
||||
return startSingleTunnel(cfg, tunnelsToStart[0])
|
||||
}
|
||||
|
||||
return startMultipleTunnels(cfg, tunnelsToStart)
|
||||
}
|
||||
|
||||
func formatTunnelInfo(t *config.TunnelConfig) string {
|
||||
addr := t.Address
|
||||
if addr == "" {
|
||||
addr = "127.0.0.1"
|
||||
}
|
||||
info := fmt.Sprintf("%-12s %s %s:%d", t.Name, t.Type, addr, t.Port)
|
||||
if t.Subdomain != "" {
|
||||
info += fmt.Sprintf(" (subdomain: %s)", t.Subdomain)
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
|
||||
connConfig := buildConnectorConfig(cfg, t)
|
||||
|
||||
fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port)
|
||||
|
||||
return runTunnelWithUI(connConfig, nil)
|
||||
}
|
||||
|
||||
func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConfig) error {
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
fmt.Println(ui.Title("Starting Tunnels"))
|
||||
fmt.Println()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(tunnels))
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Handle interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("\nShutting down tunnels...")
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
for _, t := range tunnels {
|
||||
wg.Add(1)
|
||||
go func(tunnel *config.TunnelConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
connConfig := buildConnectorConfig(cfg, tunnel)
|
||||
fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
|
||||
|
||||
client := tcp.NewTunnelClient(connConfig, logger)
|
||||
|
||||
// Connect
|
||||
if err := client.Connect(); err != nil {
|
||||
errChan <- fmt.Errorf("%s: %w", tunnel.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" ✓ %s: %s\n", tunnel.Name, client.GetURL())
|
||||
|
||||
// Run until stopped
|
||||
select {
|
||||
case <-stopChan:
|
||||
client.Close()
|
||||
}
|
||||
}(t)
|
||||
}
|
||||
|
||||
// Wait for interrupt or error
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// Collect errors
|
||||
var errors []error
|
||||
for err := range errChan {
|
||||
errors = append(errors, err)
|
||||
fmt.Printf(" ✗ %v\n", err)
|
||||
}
|
||||
|
||||
// Wait for signal if no errors
|
||||
if len(errors) == 0 {
|
||||
<-stopChan
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("%d tunnel(s) failed to start", len(errors))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
|
||||
tunnelType := protocol.TunnelTypeHTTP
|
||||
switch t.Type {
|
||||
case "https":
|
||||
tunnelType = protocol.TunnelTypeHTTPS
|
||||
case "tcp":
|
||||
tunnelType = protocol.TunnelTypeTCP
|
||||
}
|
||||
|
||||
transport := tcp.TransportAuto
|
||||
switch strings.ToLower(t.Transport) {
|
||||
case "tcp", "tls":
|
||||
transport = tcp.TransportTCP
|
||||
case "wss":
|
||||
transport = tcp.TransportWebSocket
|
||||
}
|
||||
|
||||
return &tcp.ConnectorConfig{
|
||||
ServerAddr: cfg.Server,
|
||||
Token: cfg.Token,
|
||||
TunnelType: tunnelType,
|
||||
LocalHost: getAddress(t),
|
||||
LocalPort: t.Port,
|
||||
Subdomain: t.Subdomain,
|
||||
Insecure: insecure,
|
||||
AllowIPs: t.AllowIPs,
|
||||
DenyIPs: t.DenyIPs,
|
||||
AuthPass: t.Auth,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func getAddress(t *config.TunnelConfig) string {
|
||||
if t.Address != "" {
|
||||
return t.Address
|
||||
}
|
||||
return "127.0.0.1"
|
||||
}
|
||||
Reference in New Issue
Block a user