mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 12:53:43 +00:00
- Implement client bandwidth limitation parameter --bandwidth, supporting 1M, 1MB, 1G and other formats - Added parseBandwidth function to parse bandwidth values and verify them - Added bandwidth limit option in HTTP, HTTPS, TCP commands - Pass bandwidth configuration to the server through protocol - Add relevant test cases to verify the bandwidth analysis function feat(server): implements server-side bandwidth limitation function - Add bandwidth limitation logic in connection processing, using token bucket algorithm - Implement an effective rate limiting strategy that minimizes the bandwidth of the client and server - Added QoS limiter and restricted connection wrapper - Integrated bandwidth throttling in HTTP and WebSocket proxies - Added global bandwidth limit and burst multiplier settings in server configuration docs: Updated documentation to describe bandwidth limiting functionality - Add 2025-02-14 version update instructions in README and README_CN - Add bandwidth limit function description and usage examples - Provide client and server configuration examples and parameter descriptions
260 lines
5.7 KiB
Go
260 lines
5.7 KiB
Go
package cli
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"drip/internal/client/tcp"
|
|
"drip/internal/shared/protocol"
|
|
"drip/internal/shared/ui"
|
|
"drip/internal/shared/utils"
|
|
"drip/pkg/config"
|
|
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
// startAll is bound to the --all flag of the start command; when set,
// every tunnel found in the client configuration is started.
var startAll bool
|
|
|
|
var startCmd = &cobra.Command{
|
|
Use: "start [tunnel-names...]",
|
|
Short: "Start predefined tunnels from config",
|
|
Long: `Start one or more predefined tunnels from your configuration file.
|
|
|
|
Examples:
|
|
drip start web Start the tunnel named "web"
|
|
drip start web api Start multiple tunnels
|
|
drip start --all Start all configured tunnels
|
|
|
|
Configuration file example (~/.drip/config.yaml):
|
|
server: tunnel.example.com:443
|
|
token: your-token
|
|
tls: true
|
|
tunnels:
|
|
- name: web
|
|
type: http
|
|
port: 3000
|
|
subdomain: myapp
|
|
|
|
- name: api
|
|
type: http
|
|
port: 8080
|
|
subdomain: api
|
|
transport: wss
|
|
|
|
- name: db
|
|
type: tcp
|
|
port: 5432
|
|
subdomain: postgres
|
|
allow_ips:
|
|
- 192.168.0.0/16
|
|
- 10.0.0.0/8`,
|
|
RunE: runStart,
|
|
SilenceUsage: true,
|
|
SilenceErrors: true,
|
|
}
|
|
|
|
func init() {
|
|
startCmd.Flags().BoolVar(&startAll, "all", false, "Start all configured tunnels")
|
|
rootCmd.AddCommand(startCmd)
|
|
}
|
|
|
|
func runStart(_ *cobra.Command, args []string) error {
|
|
cfg, err := config.LoadClientConfig("")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(cfg.Tunnels) == 0 {
|
|
return fmt.Errorf("no tunnels configured in %s", config.DefaultClientConfigPath())
|
|
}
|
|
|
|
var tunnelsToStart []*config.TunnelConfig
|
|
|
|
if startAll {
|
|
tunnelsToStart = cfg.Tunnels
|
|
} else if len(args) == 0 {
|
|
// No args and no --all flag, show available tunnels
|
|
fmt.Println(ui.Title("Available Tunnels"))
|
|
fmt.Println()
|
|
for _, t := range cfg.Tunnels {
|
|
fmt.Printf(" %s\n", formatTunnelInfo(t))
|
|
}
|
|
fmt.Println()
|
|
fmt.Println("Usage:")
|
|
fmt.Println(" drip start <tunnel-name> Start a specific tunnel")
|
|
fmt.Println(" drip start --all Start all tunnels")
|
|
return nil
|
|
} else {
|
|
// Start specific tunnels by name
|
|
for _, name := range args {
|
|
t := cfg.GetTunnel(name)
|
|
if t == nil {
|
|
availableNames := cfg.GetTunnelNames()
|
|
return fmt.Errorf("tunnel '%s' not found. Available tunnels: %s", name, strings.Join(availableNames, ", "))
|
|
}
|
|
tunnelsToStart = append(tunnelsToStart, t)
|
|
}
|
|
}
|
|
|
|
if len(tunnelsToStart) == 0 {
|
|
return fmt.Errorf("no tunnels to start")
|
|
}
|
|
|
|
// Start tunnels
|
|
if len(tunnelsToStart) == 1 {
|
|
return startSingleTunnel(cfg, tunnelsToStart[0])
|
|
}
|
|
|
|
return startMultipleTunnels(cfg, tunnelsToStart)
|
|
}
|
|
|
|
func formatTunnelInfo(t *config.TunnelConfig) string {
|
|
addr := t.Address
|
|
if addr == "" {
|
|
addr = "127.0.0.1"
|
|
}
|
|
info := fmt.Sprintf("%-12s %s %s:%d", t.Name, t.Type, addr, t.Port)
|
|
if t.Subdomain != "" {
|
|
info += fmt.Sprintf(" (subdomain: %s)", t.Subdomain)
|
|
}
|
|
return info
|
|
}
|
|
|
|
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
|
|
connConfig := buildConnectorConfig(cfg, t)
|
|
|
|
fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port)
|
|
|
|
return runTunnelWithUI(connConfig, nil)
|
|
}
|
|
|
|
func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConfig) error {
|
|
if err := utils.InitLogger(verbose); err != nil {
|
|
return fmt.Errorf("failed to initialize logger: %w", err)
|
|
}
|
|
defer utils.Sync()
|
|
|
|
logger := utils.GetLogger()
|
|
|
|
fmt.Println(ui.Title("Starting Tunnels"))
|
|
fmt.Println()
|
|
|
|
var wg sync.WaitGroup
|
|
errChan := make(chan error, len(tunnels))
|
|
stopChan := make(chan struct{})
|
|
|
|
// Handle interrupt signal
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
<-sigChan
|
|
fmt.Println("\nShutting down tunnels...")
|
|
close(stopChan)
|
|
}()
|
|
|
|
for _, t := range tunnels {
|
|
wg.Add(1)
|
|
go func(tunnel *config.TunnelConfig) {
|
|
defer wg.Done()
|
|
|
|
connConfig := buildConnectorConfig(cfg, tunnel)
|
|
fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
|
|
|
|
client := tcp.NewTunnelClient(connConfig, logger)
|
|
|
|
// Connect
|
|
if err := client.Connect(); err != nil {
|
|
errChan <- fmt.Errorf("%s: %w", tunnel.Name, err)
|
|
return
|
|
}
|
|
|
|
fmt.Printf(" ✓ %s: %s\n", tunnel.Name, client.GetURL())
|
|
|
|
// Run until stopped
|
|
select {
|
|
case <-stopChan:
|
|
client.Close()
|
|
}
|
|
}(t)
|
|
}
|
|
|
|
// Wait for interrupt or error
|
|
go func() {
|
|
wg.Wait()
|
|
close(errChan)
|
|
}()
|
|
|
|
// Collect errors
|
|
var errors []error
|
|
for err := range errChan {
|
|
errors = append(errors, err)
|
|
fmt.Printf(" ✗ %v\n", err)
|
|
}
|
|
|
|
// Wait for signal if no errors
|
|
if len(errors) == 0 {
|
|
<-stopChan
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if len(errors) > 0 {
|
|
return fmt.Errorf("%d tunnel(s) failed to start", len(errors))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
|
|
tunnelType := protocol.TunnelTypeHTTP
|
|
switch t.Type {
|
|
case "https":
|
|
tunnelType = protocol.TunnelTypeHTTPS
|
|
case "tcp":
|
|
tunnelType = protocol.TunnelTypeTCP
|
|
}
|
|
|
|
transport := tcp.TransportAuto
|
|
switch strings.ToLower(t.Transport) {
|
|
case "tcp", "tls":
|
|
transport = tcp.TransportTCP
|
|
case "wss":
|
|
transport = tcp.TransportWebSocket
|
|
}
|
|
|
|
return &tcp.ConnectorConfig{
|
|
ServerAddr: cfg.Server,
|
|
Token: cfg.Token,
|
|
TunnelType: tunnelType,
|
|
LocalHost: getAddress(t),
|
|
LocalPort: t.Port,
|
|
Subdomain: t.Subdomain,
|
|
Insecure: insecure,
|
|
AllowIPs: t.AllowIPs,
|
|
DenyIPs: t.DenyIPs,
|
|
AuthPass: t.Auth,
|
|
AuthBearer: t.AuthBearer,
|
|
Transport: transport,
|
|
Bandwidth: parseBandwidthOrZero(t.Bandwidth),
|
|
}
|
|
}
|
|
|
|
func getAddress(t *config.TunnelConfig) string {
|
|
if t.Address != "" {
|
|
return t.Address
|
|
}
|
|
return "127.0.0.1"
|
|
}
|
|
|
|
func parseBandwidthOrZero(s string) int64 {
|
|
bw, _ := parseBandwidth(s)
|
|
return bw
|
|
}
|