Files
drip/internal/client/cli/start.go
Gouryella 89f67ab145 feat(client): Add bandwidth limiting support
- Add a client --bandwidth parameter that accepts values such as 1M, 1MB, 1G and other formats
- Add a parseBandwidth function that parses and validates bandwidth values
- Add the bandwidth limit option to the HTTP, HTTPS and TCP commands
- Send the bandwidth configuration to the server via the protocol
- Add test cases for the bandwidth parsing function

feat(server): implement server-side bandwidth limiting

- Add bandwidth limiting to connection handling, using a token bucket algorithm
- Use the smaller of the client and server bandwidth limits as the effective rate limit
- Add a QoS limiter and a rate-limited connection wrapper
- Integrate bandwidth throttling into the HTTP and WebSocket proxies
- Add global bandwidth limit and burst multiplier settings to the server configuration

docs: Update documentation to describe bandwidth limiting

- Add 2025-02-14 release notes to README and README_CN
- Describe the bandwidth limiting feature, with usage examples
- Provide client and server configuration examples and parameter descriptions
2026-02-15 02:39:50 +08:00

260 lines
5.7 KiB
Go

package cli
import (
"fmt"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/ui"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
)
// startAll is bound to the --all flag of the start command. When it is set,
// runStart starts every tunnel in the client config and ignores any names
// passed as arguments.
var startAll bool

// startCmd implements `drip start`, which launches tunnels that are
// predefined in the client configuration file.
var startCmd = &cobra.Command{
	Use:   "start [tunnel-names...]",
	Short: "Start predefined tunnels from config",
	Long: `Start one or more predefined tunnels from your configuration file.

Examples:
  drip start web          Start the tunnel named "web"
  drip start web api      Start multiple tunnels
  drip start --all        Start all configured tunnels

Configuration file example (~/.drip/config.yaml):

  server: tunnel.example.com:443
  token: your-token
  tls: true
  tunnels:
    - name: web
      type: http
      port: 3000
      subdomain: myapp
    - name: api
      type: http
      port: 8080
      subdomain: api
      transport: wss
    - name: db
      type: tcp
      port: 5432
      subdomain: postgres
      allow_ips:
        - 192.168.0.0/16
        - 10.0.0.0/8`,
	RunE:          runStart,
	SilenceUsage:  true,
	SilenceErrors: true,
}
// init wires the --all flag onto startCmd and registers startCmd as a
// subcommand of rootCmd.
func init() {
	flags := startCmd.Flags()
	flags.BoolVar(&startAll, "all", false, "Start all configured tunnels")

	rootCmd.AddCommand(startCmd)
}
// runStart is the RunE handler for `drip start`.
//
// It loads the client config and then does one of three things:
//   - with --all, starts every configured tunnel (names in args are ignored);
//   - with no names and no --all, prints the configured tunnels and usage;
//   - otherwise, starts the tunnels named in args, in the order given.
//
// A name repeated in args is started only once. An unknown name aborts with
// an error listing the available tunnel names. A single tunnel is run via
// startSingleTunnel; several are run via startMultipleTunnels.
func runStart(_ *cobra.Command, args []string) error {
	cfg, err := config.LoadClientConfig("")
	if err != nil {
		return err
	}
	if len(cfg.Tunnels) == 0 {
		return fmt.Errorf("no tunnels configured in %s", config.DefaultClientConfigPath())
	}

	var tunnelsToStart []*config.TunnelConfig
	switch {
	case startAll:
		tunnelsToStart = cfg.Tunnels

	case len(args) == 0:
		// No names and no --all: list what is available instead of starting anything.
		fmt.Println(ui.Title("Available Tunnels"))
		fmt.Println()
		for _, t := range cfg.Tunnels {
			fmt.Printf(" %s\n", formatTunnelInfo(t))
		}
		fmt.Println()
		fmt.Println("Usage:")
		fmt.Println(" drip start <tunnel-name> Start a specific tunnel")
		fmt.Println(" drip start --all Start all tunnels")
		return nil

	default:
		// Skip repeated names; otherwise the same tunnel config would be
		// launched twice by startMultipleTunnels.
		seen := make(map[string]struct{}, len(args))
		for _, name := range args {
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}

			t := cfg.GetTunnel(name)
			if t == nil {
				availableNames := cfg.GetTunnelNames()
				return fmt.Errorf("tunnel '%s' not found. Available tunnels: %s", name, strings.Join(availableNames, ", "))
			}
			tunnelsToStart = append(tunnelsToStart, t)
		}
	}

	if len(tunnelsToStart) == 0 {
		return fmt.Errorf("no tunnels to start")
	}
	if len(tunnelsToStart) == 1 {
		return startSingleTunnel(cfg, tunnelsToStart[0])
	}
	return startMultipleTunnels(cfg, tunnelsToStart)
}
// formatTunnelInfo renders a one-line summary of t for the tunnel listing.
// The line contains the name (left-aligned, padded to at least 12 characters),
// the type and the local host:port. When a subdomain is configured, it is
// appended as " (subdomain: …)".
//
// The host comes from getAddress, so the listing uses the same default local
// address as the tunnels actually started by this command.
func formatTunnelInfo(t *config.TunnelConfig) string {
	info := fmt.Sprintf("%-12s %s %s:%d", t.Name, t.Type, getAddress(t), t.Port)
	if t.Subdomain != "" {
		info += fmt.Sprintf(" (subdomain: %s)", t.Subdomain)
	}
	return info
}
// startSingleTunnel builds the connector config for t, prints a one-line
// "Starting tunnel" banner and then hands control to runTunnelWithUI,
// returning whatever error it reports.
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
	cc := buildConnectorConfig(cfg, t)

	localHost := getAddress(t)
	fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, localHost, t.Port)

	return runTunnelWithUI(cc, nil)
}
// startMultipleTunnels connects every tunnel in tunnels concurrently and keeps
// the connected ones open until SIGINT or SIGTERM is received, at which point
// each connected client is closed.
//
// Connection failures are printed as they are reported. Tunnels that did
// connect keep running until the signal arrives. The function returns an error
// counting the tunnels that failed to connect, or nil if all of them connected.
func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConfig) error {
	if err := utils.InitLogger(verbose); err != nil {
		return fmt.Errorf("failed to initialize logger: %w", err)
	}
	defer utils.Sync()
	logger := utils.GetLogger()

	fmt.Println(ui.Title("Starting Tunnels"))
	fmt.Println()

	var wg sync.WaitGroup
	errChan := make(chan error, len(tunnels))
	stopChan := make(chan struct{})

	// Translate SIGINT/SIGTERM into closing stopChan.
	// signal.Stop is deferred so the handler is unregistered when we return.
	// done lets the watcher goroutine exit if no signal ever arrives (for
	// example, when every tunnel failed to connect).
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-sigChan:
			// Stop relaying further signals to sigChan. Unless other code has
			// its own Notify registration, a second interrupt then falls back
			// to the default behavior instead of being swallowed while the
			// tunnels close.
			signal.Stop(sigChan)
			fmt.Println("\nShutting down tunnels...")
			close(stopChan)
		case <-done:
		}
	}()

	for _, t := range tunnels {
		wg.Add(1)
		go func(tunnel *config.TunnelConfig) {
			defer wg.Done()
			connConfig := buildConnectorConfig(cfg, tunnel)
			fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
			client := tcp.NewTunnelClient(connConfig, logger)
			if err := client.Connect(); err != nil {
				errChan <- fmt.Errorf("%s: %w", tunnel.Name, err)
				return
			}
			fmt.Printf(" ✓ %s: %s\n", tunnel.Name, client.GetURL())

			// Keep the tunnel open until shutdown is requested.
			<-stopChan
			client.Close()
		}(t)
	}

	// Close errChan once every tunnel goroutine has returned. A goroutine whose
	// tunnel connected returns only after stopChan is closed, so the loop
	// below runs until shutdown, or until every tunnel has failed.
	go func() {
		wg.Wait()
		close(errChan)
	}()

	var failures []error
	for err := range errChan {
		failures = append(failures, err)
		fmt.Printf(" ✗ %v\n", err)
	}

	// All goroutines have exited here. With no failures, stopChan is
	// therefore already closed unless tunnels was empty, in which case
	// this still waits for a signal as before.
	if len(failures) == 0 {
		<-stopChan
	}

	if len(failures) > 0 {
		return fmt.Errorf("%d tunnel(s) failed to start", len(failures))
	}
	return nil
}
// buildConnectorConfig translates a tunnel entry from the client config into
// the tcp.ConnectorConfig used to connect to the server.
//
// Both t.Type and t.Transport are matched case-insensitively:
//   - type: "https" and "tcp" select those tunnel types; anything else,
//     including an empty value, falls back to HTTP;
//   - transport: "tcp" and "tls" both select tcp.TransportTCP, "wss" selects
//     tcp.TransportWebSocket, and anything else selects tcp.TransportAuto.
//
// The server address and token come from cfg. Insecure comes from the
// package-level insecure variable. Bandwidth is converted with
// parseBandwidthOrZero, so a value that fails to parse does not abort startup.
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
	tunnelType := protocol.TunnelTypeHTTP
	// Lowercase like Transport below, so "HTTPS"/"TCP" in the config are not
	// silently treated as HTTP tunnels.
	switch strings.ToLower(t.Type) {
	case "https":
		tunnelType = protocol.TunnelTypeHTTPS
	case "tcp":
		tunnelType = protocol.TunnelTypeTCP
	}

	transport := tcp.TransportAuto
	switch strings.ToLower(t.Transport) {
	case "tcp", "tls":
		transport = tcp.TransportTCP
	case "wss":
		transport = tcp.TransportWebSocket
	}

	return &tcp.ConnectorConfig{
		ServerAddr: cfg.Server,
		Token:      cfg.Token,
		TunnelType: tunnelType,
		LocalHost:  getAddress(t),
		LocalPort:  t.Port,
		Subdomain:  t.Subdomain,
		Insecure:   insecure,
		AllowIPs:   t.AllowIPs,
		DenyIPs:    t.DenyIPs,
		AuthPass:   t.Auth,
		AuthBearer: t.AuthBearer,
		Transport:  transport,
		Bandwidth:  parseBandwidthOrZero(t.Bandwidth),
	}
}
// getAddress returns the local host that tunnel t forwards traffic to.
// It falls back to "127.0.0.1" when the config leaves the address empty.
func getAddress(t *config.TunnelConfig) string {
	addr := t.Address
	if addr == "" {
		addr = "127.0.0.1"
	}
	return addr
}
// parseBandwidthOrZero converts a bandwidth string from the config file with
// parseBandwidth. It returns 0 when s is empty (or only whitespace) or when
// parsing fails.
//
// A parse failure is reported on stderr rather than dropped silently, so a
// typo such as "10X" is visible instead of quietly becoming 0. 0 presumably
// means "no limit" for tcp.ConnectorConfig.Bandwidth; confirm against the
// connector.
func parseBandwidthOrZero(s string) int64 {
	// An unset bandwidth is the normal case; don't parse or warn about it.
	if strings.TrimSpace(s) == "" {
		return 0
	}
	bw, err := parseBandwidth(s)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: ignoring invalid bandwidth %q: %v\n", s, err)
		return 0
	}
	return bw
}