mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-26 14:21:17 +00:00
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
@@ -1,280 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Decoder decompresses HPACK-encoded headers
|
||||
// Each connection MUST have its own decoder instance to maintain correct state
|
||||
type Decoder struct {
|
||||
mu sync.Mutex
|
||||
dynamicTable *DynamicTable
|
||||
staticTable *StaticTable
|
||||
maxTableSize uint32
|
||||
}
|
||||
|
||||
// NewDecoder creates a new HPACK decoder
|
||||
func NewDecoder(maxTableSize uint32) *Decoder {
|
||||
if maxTableSize == 0 {
|
||||
maxTableSize = DefaultDynamicTableSize
|
||||
}
|
||||
|
||||
return &Decoder{
|
||||
dynamicTable: NewDynamicTable(maxTableSize),
|
||||
staticTable: GetStaticTable(),
|
||||
maxTableSize: maxTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes HPACK-encoded headers
|
||||
func (d *Decoder) Decode(data []byte) (http.Header, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if len(data) == 0 {
|
||||
return http.Header{}, nil
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
for buf.Len() > 0 {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read header byte: %w", err)
|
||||
}
|
||||
|
||||
// Unread the byte so we can process it properly
|
||||
if err := buf.UnreadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var name, value string
|
||||
|
||||
if b&0x80 != 0 {
|
||||
// Indexed header field (10xxxxxx)
|
||||
name, value, err = d.decodeIndexedHeader(buf)
|
||||
} else if b&0x40 != 0 {
|
||||
// Literal with incremental indexing (01xxxxxx)
|
||||
name, value, err = d.decodeLiteralWithIndexing(buf)
|
||||
} else {
|
||||
// Literal without indexing (0000xxxx)
|
||||
name, value, err = d.decodeLiteralWithoutIndexing(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers.Add(name, value)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// decodeIndexedHeader decodes an indexed header field
|
||||
func (d *Decoder) decodeIndexedHeader(buf *bytes.Reader) (string, string, error) {
|
||||
index, err := d.readInteger(buf, 7)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read index: %w", err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
return "", "", errors.New("invalid index: 0")
|
||||
}
|
||||
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
|
||||
if index <= staticSize {
|
||||
// Static table
|
||||
return d.staticTable.Get(index - 1)
|
||||
}
|
||||
|
||||
// Dynamic table (indices start after static table)
|
||||
dynamicIndex := index - staticSize - 1
|
||||
return d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
|
||||
// decodeLiteralWithIndexing decodes a literal header with incremental indexing
|
||||
func (d *Decoder) decodeLiteralWithIndexing(buf *bytes.Reader) (string, string, error) {
|
||||
nameIndex, err := d.readInteger(buf, 6)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var name string
|
||||
if nameIndex == 0 {
|
||||
// Name is literal
|
||||
name, err = d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read name: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Name is indexed
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
if nameIndex <= staticSize {
|
||||
name, _, err = d.staticTable.Get(nameIndex - 1)
|
||||
} else {
|
||||
dynamicIndex := nameIndex - staticSize - 1
|
||||
name, _, err = d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get indexed name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Value is always literal
|
||||
value, err := d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read value: %w", err)
|
||||
}
|
||||
|
||||
// Add to dynamic table
|
||||
d.dynamicTable.Add(name, value)
|
||||
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
// decodeLiteralWithoutIndexing decodes a literal header without indexing
|
||||
func (d *Decoder) decodeLiteralWithoutIndexing(buf *bytes.Reader) (string, string, error) {
|
||||
nameIndex, err := d.readInteger(buf, 4)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var name string
|
||||
if nameIndex == 0 {
|
||||
// Name is literal
|
||||
name, err = d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read name: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Name is indexed
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
if nameIndex <= staticSize {
|
||||
name, _, err = d.staticTable.Get(nameIndex - 1)
|
||||
} else {
|
||||
dynamicIndex := nameIndex - staticSize - 1
|
||||
name, _, err = d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get indexed name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Value is always literal
|
||||
value, err := d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read value: %w", err)
|
||||
}
|
||||
|
||||
// Do NOT add to dynamic table
|
||||
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
// readInteger reads an HPACK integer
|
||||
func (d *Decoder) readInteger(buf *bytes.Reader, prefixBits int) (uint32, error) {
|
||||
if prefixBits < 1 || prefixBits > 8 {
|
||||
return 0, fmt.Errorf("invalid prefix bits: %d", prefixBits)
|
||||
}
|
||||
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
maxPrefix := uint32((1 << prefixBits) - 1)
|
||||
mask := byte(maxPrefix)
|
||||
|
||||
value := uint32(b & mask)
|
||||
if value < maxPrefix {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Multi-byte integer
|
||||
m := uint32(0)
|
||||
for {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
value += uint32(b&0x7f) << m
|
||||
m += 7
|
||||
|
||||
if b&0x80 == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if m > 28 {
|
||||
return 0, errors.New("integer overflow")
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// readString reads an HPACK string
|
||||
func (d *Decoder) readString(buf *bytes.Reader) (string, error) {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := buf.UnreadByte(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
huffmanEncoded := (b & 0x80) != 0
|
||||
|
||||
length, err := d.readInteger(buf, 7)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read string length: %w", err)
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if length > uint32(buf.Len()) {
|
||||
return "", fmt.Errorf("string length %d exceeds buffer size %d", length, buf.Len())
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err := buf.Read(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n != int(length) {
|
||||
return "", fmt.Errorf("expected %d bytes, read %d", length, n)
|
||||
}
|
||||
|
||||
if huffmanEncoded {
|
||||
// TODO: Implement Huffman decoding if needed
|
||||
return "", errors.New("huffman decoding not implemented")
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// SetMaxTableSize updates the dynamic table size
|
||||
func (d *Decoder) SetMaxTableSize(size uint32) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.maxTableSize = size
|
||||
d.dynamicTable.SetMaxSize(size)
|
||||
}
|
||||
|
||||
// Reset clears the dynamic table
|
||||
func (d *Decoder) Reset() {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.dynamicTable = NewDynamicTable(d.maxTableSize)
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DynamicTable implements the HPACK dynamic table (RFC 7541 Section 2.3.2)
|
||||
// The dynamic table is a FIFO queue where new entries are added at the beginning
|
||||
// and old entries are evicted when the table size exceeds the maximum
|
||||
type DynamicTable struct {
|
||||
entries []HeaderField
|
||||
size uint32 // Current size in bytes
|
||||
maxSize uint32 // Maximum size in bytes
|
||||
}
|
||||
|
||||
// HeaderField represents a header name-value pair
|
||||
type HeaderField struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Size returns the size of this header field in bytes
|
||||
// RFC 7541: size = len(name) + len(value) + 32
|
||||
func (h *HeaderField) Size() uint32 {
|
||||
return uint32(len(h.Name) + len(h.Value) + 32)
|
||||
}
|
||||
|
||||
// NewDynamicTable creates a new dynamic table with the specified maximum size
|
||||
func NewDynamicTable(maxSize uint32) *DynamicTable {
|
||||
return &DynamicTable{
|
||||
entries: make([]HeaderField, 0, 32),
|
||||
size: 0,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a header field to the dynamic table
|
||||
// New entries are added at the beginning (index 0)
|
||||
func (dt *DynamicTable) Add(name, value string) {
|
||||
field := HeaderField{Name: name, Value: value}
|
||||
fieldSize := field.Size()
|
||||
|
||||
// If the field is larger than maxSize, don't add it
|
||||
if fieldSize > dt.maxSize {
|
||||
dt.evictAll()
|
||||
return
|
||||
}
|
||||
|
||||
// Evict entries if necessary to make room
|
||||
for dt.size+fieldSize > dt.maxSize && len(dt.entries) > 0 {
|
||||
dt.evictOldest()
|
||||
}
|
||||
|
||||
// Add new entry at the beginning
|
||||
dt.entries = append([]HeaderField{field}, dt.entries...)
|
||||
dt.size += fieldSize
|
||||
}
|
||||
|
||||
// Get retrieves a header field by index (0-based)
|
||||
// Index 0 is the most recently added entry
|
||||
func (dt *DynamicTable) Get(index uint32) (string, string, error) {
|
||||
if index >= uint32(len(dt.entries)) {
|
||||
return "", "", fmt.Errorf("index %d out of range (table size: %d)", index, len(dt.entries))
|
||||
}
|
||||
|
||||
field := dt.entries[index]
|
||||
return field.Name, field.Value, nil
|
||||
}
|
||||
|
||||
// FindExact searches for an exact match (name and value)
|
||||
// Returns the index (0-based) and true if found
|
||||
func (dt *DynamicTable) FindExact(name, value string) (uint32, bool) {
|
||||
for i, field := range dt.entries {
|
||||
if field.Name == name && field.Value == value {
|
||||
return uint32(i), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// FindName searches for a name match
|
||||
// Returns the index (0-based) and true if found
|
||||
func (dt *DynamicTable) FindName(name string) (uint32, bool) {
|
||||
for i, field := range dt.entries {
|
||||
if field.Name == name {
|
||||
return uint32(i), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// SetMaxSize updates the maximum table size
|
||||
// If the new size is smaller, entries are evicted
|
||||
func (dt *DynamicTable) SetMaxSize(maxSize uint32) {
|
||||
dt.maxSize = maxSize
|
||||
|
||||
// Evict entries if current size exceeds new max
|
||||
for dt.size > dt.maxSize && len(dt.entries) > 0 {
|
||||
dt.evictOldest()
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentSize returns the current size of the table in bytes
|
||||
func (dt *DynamicTable) CurrentSize() uint32 {
|
||||
return dt.size
|
||||
}
|
||||
|
||||
// evictOldest removes the oldest entry (last in the slice)
|
||||
func (dt *DynamicTable) evictOldest() {
|
||||
if len(dt.entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
lastIndex := len(dt.entries) - 1
|
||||
evicted := dt.entries[lastIndex]
|
||||
dt.entries = dt.entries[:lastIndex]
|
||||
dt.size -= evicted.Size()
|
||||
}
|
||||
|
||||
// evictAll removes all entries
|
||||
func (dt *DynamicTable) evictAll() {
|
||||
dt.entries = dt.entries[:0]
|
||||
dt.size = 0
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDynamicTableSize is the default size of the dynamic table (4KB)
|
||||
DefaultDynamicTableSize = 4096
|
||||
|
||||
// IndexedHeaderField represents a fully indexed header field
|
||||
indexedHeaderField = 0x80 // 10xxxxxx
|
||||
|
||||
// LiteralHeaderFieldWithIndexing represents a literal with incremental indexing
|
||||
literalHeaderFieldWithIndexing = 0x40 // 01xxxxxx
|
||||
)
|
||||
|
||||
// Encoder compresses HTTP headers using HPACK
|
||||
// Each connection MUST have its own encoder instance to avoid state corruption
|
||||
type Encoder struct {
|
||||
mu sync.Mutex
|
||||
dynamicTable *DynamicTable
|
||||
staticTable *StaticTable
|
||||
maxTableSize uint32
|
||||
}
|
||||
|
||||
// NewEncoder creates a new HPACK encoder with the specified dynamic table size
|
||||
// This encoder is NOT thread-safe and should be used by a single connection
|
||||
func NewEncoder(maxTableSize uint32) *Encoder {
|
||||
if maxTableSize == 0 {
|
||||
maxTableSize = DefaultDynamicTableSize
|
||||
}
|
||||
|
||||
return &Encoder{
|
||||
dynamicTable: NewDynamicTable(maxTableSize),
|
||||
staticTable: GetStaticTable(),
|
||||
maxTableSize: maxTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes HTTP headers into HPACK binary format
|
||||
// This method is safe to call concurrently within the same encoder instance
|
||||
func (e *Encoder) Encode(headers http.Header) ([]byte, error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if headers == nil {
|
||||
return nil, errors.New("headers cannot be nil")
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for name, values := range headers {
|
||||
for _, value := range values {
|
||||
if err := e.encodeHeaderField(buf, name, value); err != nil {
|
||||
return nil, fmt.Errorf("encode header %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// encodeHeaderField encodes a single header field
|
||||
func (e *Encoder) encodeHeaderField(buf *bytes.Buffer, name, value string) error {
|
||||
// HTTP/2 requires header names to be lowercase (RFC 7540 Section 8.1.2)
|
||||
// Convert to lowercase for table lookups and storage
|
||||
nameLower := strings.ToLower(name)
|
||||
|
||||
// Try to find in static table first
|
||||
if index, found := e.staticTable.FindExact(nameLower, value); found {
|
||||
return e.writeIndexedHeader(buf, index+1)
|
||||
}
|
||||
|
||||
// Check if name exists in static table (for literal with name reference)
|
||||
if index, found := e.staticTable.FindName(nameLower); found {
|
||||
return e.writeLiteralWithIndexing(buf, index+1, value, true)
|
||||
}
|
||||
|
||||
// Try dynamic table
|
||||
if index, found := e.dynamicTable.FindExact(nameLower, value); found {
|
||||
// Dynamic table indices start after static table
|
||||
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
|
||||
return e.writeIndexedHeader(buf, dynamicIndex)
|
||||
}
|
||||
|
||||
if index, found := e.dynamicTable.FindName(nameLower); found {
|
||||
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
|
||||
return e.writeLiteralWithIndexing(buf, dynamicIndex, value, true)
|
||||
}
|
||||
|
||||
// Not found anywhere - literal with indexing and new name
|
||||
// Write literal flag
|
||||
buf.WriteByte(literalHeaderFieldWithIndexing)
|
||||
|
||||
// Write name as literal string (must come before value)
|
||||
// Use lowercase name for consistency
|
||||
if err := e.writeString(buf, nameLower, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write value as literal string
|
||||
if err := e.writeString(buf, value, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add to dynamic table with lowercase name
|
||||
e.dynamicTable.Add(nameLower, value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeIndexedHeader writes an indexed header field (10xxxxxx)
|
||||
func (e *Encoder) writeIndexedHeader(buf *bytes.Buffer, index uint32) error {
|
||||
return e.writeInteger(buf, index, 7, indexedHeaderField)
|
||||
}
|
||||
|
||||
// writeLiteralWithIndexing writes a literal header with incremental indexing (01xxxxxx)
|
||||
func (e *Encoder) writeLiteralWithIndexing(buf *bytes.Buffer, nameIndex uint32, value string, hasIndex bool) error {
|
||||
if hasIndex {
|
||||
// Write name as index
|
||||
if err := e.writeInteger(buf, nameIndex, 6, literalHeaderFieldWithIndexing); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Write literal flag
|
||||
buf.WriteByte(literalHeaderFieldWithIndexing)
|
||||
}
|
||||
|
||||
// Write value as literal string
|
||||
return e.writeString(buf, value, false)
|
||||
}
|
||||
|
||||
// writeInteger writes an integer using HPACK integer representation
|
||||
func (e *Encoder) writeInteger(buf *bytes.Buffer, value uint32, prefixBits int, prefix byte) error {
|
||||
if prefixBits < 1 || prefixBits > 8 {
|
||||
return fmt.Errorf("invalid prefix bits: %d", prefixBits)
|
||||
}
|
||||
|
||||
maxPrefix := uint32((1 << prefixBits) - 1)
|
||||
|
||||
if value < maxPrefix {
|
||||
buf.WriteByte(prefix | byte(value))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value >= maxPrefix, need multiple bytes
|
||||
buf.WriteByte(prefix | byte(maxPrefix))
|
||||
value -= maxPrefix
|
||||
|
||||
for value >= 128 {
|
||||
buf.WriteByte(byte(value%128) | 0x80)
|
||||
value /= 128
|
||||
}
|
||||
buf.WriteByte(byte(value))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeString writes a string using HPACK string representation
|
||||
func (e *Encoder) writeString(buf *bytes.Buffer, str string, huffmanEncode bool) error {
|
||||
// For simplicity, we don't use Huffman encoding in this implementation
|
||||
// Huffman flag is bit 7, followed by length in remaining 7 bits
|
||||
|
||||
length := uint32(len(str))
|
||||
if huffmanEncode {
|
||||
// TODO: Implement Huffman encoding if needed
|
||||
return errors.New("huffman encoding not implemented")
|
||||
}
|
||||
|
||||
// Write length with H=0 (no Huffman)
|
||||
if err := e.writeInteger(buf, length, 7, 0x00); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write string bytes
|
||||
buf.WriteString(str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxTableSize updates the dynamic table size
|
||||
func (e *Encoder) SetMaxTableSize(size uint32) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.maxTableSize = size
|
||||
e.dynamicTable.SetMaxSize(size)
|
||||
}
|
||||
|
||||
// Reset clears the dynamic table
|
||||
func (e *Encoder) Reset() {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.dynamicTable = NewDynamicTable(e.maxTableSize)
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// StaticTable implements the HPACK static table (RFC 7541 Appendix A)
|
||||
// The static table is predefined and never changes
|
||||
type StaticTable struct {
|
||||
entries []HeaderField
|
||||
nameMap map[string][]uint32 // Maps name to list of indices
|
||||
}
|
||||
|
||||
var (
|
||||
staticTableInstance *StaticTable
|
||||
staticTableOnce sync.Once
|
||||
)
|
||||
|
||||
// GetStaticTable returns the singleton static table instance
|
||||
func GetStaticTable() *StaticTable {
|
||||
staticTableOnce.Do(func() {
|
||||
staticTableInstance = newStaticTable()
|
||||
})
|
||||
return staticTableInstance
|
||||
}
|
||||
|
||||
// newStaticTable creates and initializes the static table
|
||||
func newStaticTable() *StaticTable {
|
||||
// RFC 7541 Appendix A - Static Table Definition
|
||||
// We include the most common headers for HTTP
|
||||
entries := []HeaderField{
|
||||
{Name: ":authority", Value: ""},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":method", Value: "POST"},
|
||||
{Name: ":path", Value: "/"},
|
||||
{Name: ":path", Value: "/index.html"},
|
||||
{Name: ":scheme", Value: "http"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: ":status", Value: "204"},
|
||||
{Name: ":status", Value: "206"},
|
||||
{Name: ":status", Value: "304"},
|
||||
{Name: ":status", Value: "400"},
|
||||
{Name: ":status", Value: "404"},
|
||||
{Name: ":status", Value: "500"},
|
||||
{Name: "accept-charset", Value: ""},
|
||||
{Name: "accept-encoding", Value: "gzip, deflate"},
|
||||
{Name: "accept-language", Value: ""},
|
||||
{Name: "accept-ranges", Value: ""},
|
||||
{Name: "accept", Value: ""},
|
||||
{Name: "access-control-allow-origin", Value: ""},
|
||||
{Name: "age", Value: ""},
|
||||
{Name: "allow", Value: ""},
|
||||
{Name: "authorization", Value: ""},
|
||||
{Name: "cache-control", Value: ""},
|
||||
{Name: "content-disposition", Value: ""},
|
||||
{Name: "content-encoding", Value: ""},
|
||||
{Name: "content-language", Value: ""},
|
||||
{Name: "content-length", Value: ""},
|
||||
{Name: "content-location", Value: ""},
|
||||
{Name: "content-range", Value: ""},
|
||||
{Name: "content-type", Value: ""},
|
||||
{Name: "cookie", Value: ""},
|
||||
{Name: "date", Value: ""},
|
||||
{Name: "etag", Value: ""},
|
||||
{Name: "expect", Value: ""},
|
||||
{Name: "expires", Value: ""},
|
||||
{Name: "from", Value: ""},
|
||||
{Name: "host", Value: ""},
|
||||
{Name: "if-match", Value: ""},
|
||||
{Name: "if-modified-since", Value: ""},
|
||||
{Name: "if-none-match", Value: ""},
|
||||
{Name: "if-range", Value: ""},
|
||||
{Name: "if-unmodified-since", Value: ""},
|
||||
{Name: "last-modified", Value: ""},
|
||||
{Name: "link", Value: ""},
|
||||
{Name: "location", Value: ""},
|
||||
{Name: "max-forwards", Value: ""},
|
||||
{Name: "proxy-authenticate", Value: ""},
|
||||
{Name: "proxy-authorization", Value: ""},
|
||||
{Name: "range", Value: ""},
|
||||
{Name: "referer", Value: ""},
|
||||
{Name: "refresh", Value: ""},
|
||||
{Name: "retry-after", Value: ""},
|
||||
{Name: "server", Value: ""},
|
||||
{Name: "set-cookie", Value: ""},
|
||||
{Name: "strict-transport-security", Value: ""},
|
||||
{Name: "transfer-encoding", Value: ""},
|
||||
{Name: "user-agent", Value: ""},
|
||||
{Name: "vary", Value: ""},
|
||||
{Name: "via", Value: ""},
|
||||
{Name: "www-authenticate", Value: ""},
|
||||
}
|
||||
|
||||
// Build name index map
|
||||
nameMap := make(map[string][]uint32)
|
||||
for i, entry := range entries {
|
||||
nameMap[entry.Name] = append(nameMap[entry.Name], uint32(i))
|
||||
}
|
||||
|
||||
return &StaticTable{
|
||||
entries: entries,
|
||||
nameMap: nameMap,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header field by index (0-based)
|
||||
func (st *StaticTable) Get(index uint32) (string, string, error) {
|
||||
if index >= uint32(len(st.entries)) {
|
||||
return "", "", fmt.Errorf("index %d out of range (static table size: %d)", index, len(st.entries))
|
||||
}
|
||||
|
||||
field := st.entries[index]
|
||||
return field.Name, field.Value, nil
|
||||
}
|
||||
|
||||
// FindExact searches for an exact match (name and value)
|
||||
// Returns the index (0-based) and true if found
|
||||
func (st *StaticTable) FindExact(name, value string) (uint32, bool) {
|
||||
indices, exists := st.nameMap[name]
|
||||
if !exists {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, index := range indices {
|
||||
field := st.entries[index]
|
||||
if field.Value == value {
|
||||
return index, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// FindName searches for a name match
|
||||
// Returns the first matching index (0-based) and true if found
|
||||
func (st *StaticTable) FindName(name string) (uint32, bool) {
|
||||
indices, exists := st.nameMap[name]
|
||||
if !exists || len(indices) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return indices[0], true
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the static table
|
||||
func (st *StaticTable) Size() int {
|
||||
return len(st.entries)
|
||||
}
|
||||
@@ -9,6 +9,10 @@ const (
|
||||
// DefaultWSPort is the default WebSocket port
|
||||
DefaultWSPort = 8080
|
||||
|
||||
// YamuxAcceptBacklog controls how many incoming streams can be queued
|
||||
// before yamux starts blocking stream opens under load.
|
||||
YamuxAcceptBacklog = 4096
|
||||
|
||||
// HeartbeatInterval is how often clients send heartbeat messages
|
||||
HeartbeatInterval = 2 * time.Second
|
||||
|
||||
|
||||
71
internal/shared/httputil/helpers.go
Normal file
71
internal/shared/httputil/helpers.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CopyHeaders copies all headers from src to dst.
|
||||
func CopyHeaders(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CleanHopByHopHeaders removes hop-by-hop headers that should not be forwarded.
|
||||
func CleanHopByHopHeaders(headers http.Header) {
|
||||
if headers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if connectionHeaders := headers.Get("Connection"); connectionHeaders != "" {
|
||||
for _, token := range strings.Split(connectionHeaders, ",") {
|
||||
if t := strings.TrimSpace(token); t != "" {
|
||||
headers.Del(http.CanonicalHeaderKey(t))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
"Proxy-Connection",
|
||||
} {
|
||||
headers.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteProxyError writes an HTTP error response to the writer.
|
||||
func WriteProxyError(w io.Writer, code int, msg string) {
|
||||
body := msg
|
||||
resp := &http.Response{
|
||||
StatusCode: code,
|
||||
Status: fmt.Sprintf("%d %s", code, http.StatusText(code)),
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
Close: true,
|
||||
}
|
||||
resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
|
||||
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
_ = resp.Write(w)
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade checks if the request is a WebSocket upgrade request.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") &&
|
||||
strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||
}
|
||||
35
internal/shared/netutil/counting_conn.go
Normal file
35
internal/shared/netutil/counting_conn.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package netutil
|
||||
|
||||
import "net"
|
||||
|
||||
// CountingConn wraps a net.Conn to count bytes read/written.
|
||||
type CountingConn struct {
|
||||
net.Conn
|
||||
OnRead func(int64)
|
||||
OnWrite func(int64)
|
||||
}
|
||||
|
||||
// NewCountingConn creates a new CountingConn.
|
||||
func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
|
||||
return &CountingConn{
|
||||
Conn: conn,
|
||||
OnRead: onRead,
|
||||
OnWrite: onWrite,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CountingConn) Read(p []byte) (int, error) {
|
||||
n, err := c.Conn.Read(p)
|
||||
if n > 0 && c.OnRead != nil {
|
||||
c.OnRead(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *CountingConn) Write(p []byte) (int, error) {
|
||||
n, err := c.Conn.Write(p)
|
||||
if n > 0 && c.OnWrite != nil {
|
||||
c.OnWrite(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
164
internal/shared/netutil/pipe.go
Normal file
164
internal/shared/netutil/pipe.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
const tcpWaitTimeout = 10 * time.Second
|
||||
|
||||
type closeReader interface {
|
||||
CloseRead() error
|
||||
}
|
||||
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type readDeadliner interface {
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// Pipe copies bytes bidirectionally between a and b (gost-like),
|
||||
// and applies TCP half-close when supported.
|
||||
func Pipe(ctx context.Context, a, b io.ReadWriteCloser) error {
|
||||
return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, nil, nil)
|
||||
}
|
||||
|
||||
// PipeWithCallbacks is Pipe with optional byte counters for each direction:
|
||||
// onAToB is called with bytes copied from a -> b, onBToA for b -> a.
|
||||
func PipeWithCallbacks(ctx context.Context, a, b io.ReadWriteCloser, onAToB func(n int64), onBToA func(n int64)) error {
|
||||
return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, onAToB, onBToA)
|
||||
}
|
||||
|
||||
// PipeWithBufferSize is Pipe with a custom buffer size.
|
||||
func PipeWithBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int) error {
|
||||
return PipeWithCallbacksAndBufferSize(ctx, a, b, bufSize, nil, nil)
|
||||
}
|
||||
|
||||
// PipeWithCallbacksAndBufferSize is PipeWithCallbacks with a custom buffer size.
|
||||
func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int, onAToB func(n int64), onBToA func(n int64)) error {
|
||||
if bufSize <= 0 {
|
||||
bufSize = pool.SizeMedium
|
||||
}
|
||||
if bufSize > pool.SizeLarge {
|
||||
bufSize = pool.SizeLarge
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
var closeOnce sync.Once
|
||||
closeAll := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(stopCh)
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
})
|
||||
}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
if ctx != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closeAll()
|
||||
case <-stopCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
closeAll()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := pipeBuffer(a, b, bufSize, onBToA, stopCh)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
closeAll()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pipeBuffer(dst io.ReadWriteCloser, src io.ReadWriteCloser, bufSize int, onCopied func(n int64), stopCh <-chan struct{}) error {
|
||||
bufPtr := pool.GetBuffer(bufSize)
|
||||
defer pool.PutBuffer(bufPtr)
|
||||
|
||||
buf := (*bufPtr)[:bufSize]
|
||||
_, err := copyBuffer(dst, src, buf, onCopied, stopCh)
|
||||
|
||||
if cr, ok := src.(closeReader); ok {
|
||||
_ = cr.CloseRead()
|
||||
}
|
||||
|
||||
if cw, ok := dst.(closeWriter); ok {
|
||||
if e := cw.CloseWrite(); e != nil {
|
||||
_ = dst.Close()
|
||||
}
|
||||
if rd, ok := dst.(readDeadliner); ok {
|
||||
_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
|
||||
}
|
||||
} else {
|
||||
_ = dst.Close()
|
||||
if rd, ok := dst.(readDeadliner); ok {
|
||||
_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func copyBuffer(dst io.Writer, src io.Reader, buf []byte, onCopied func(n int64), stopCh <-chan struct{}) (written int64, err error) {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return written, io.EOF
|
||||
default:
|
||||
}
|
||||
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
if onCopied != nil {
|
||||
onCopied(int64(nw))
|
||||
}
|
||||
}
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er == io.EOF {
|
||||
return written, nil
|
||||
}
|
||||
return written, er
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// AdaptiveBufferPool manages reusable buffers of different sizes
|
||||
// This eliminates the massive memory allocation overhead seen in profiling
|
||||
type AdaptiveBufferPool struct {
|
||||
// Large buffers for streaming threshold (1MB)
|
||||
largePool *sync.Pool
|
||||
|
||||
// Medium buffers for temporary reads (32KB)
|
||||
mediumPool *sync.Pool
|
||||
}
|
||||
|
||||
const (
|
||||
// LargeBufferSize is 1MB for streaming threshold
|
||||
LargeBufferSize = 1 * 1024 * 1024
|
||||
|
||||
// MediumBufferSize is 32KB for temporary reads
|
||||
MediumBufferSize = 32 * 1024
|
||||
)
|
||||
|
||||
// NewAdaptiveBufferPool creates a new adaptive buffer pool
|
||||
func NewAdaptiveBufferPool() *AdaptiveBufferPool {
|
||||
return &AdaptiveBufferPool{
|
||||
largePool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, LargeBufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
mediumPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, MediumBufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetLarge returns a large buffer (1MB) from the pool
|
||||
// The returned buffer should be returned via PutLarge when done
|
||||
func (p *AdaptiveBufferPool) GetLarge() *[]byte {
|
||||
return p.largePool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutLarge returns a large buffer to the pool for reuse
|
||||
func (p *AdaptiveBufferPool) PutLarge(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
// Reset to full capacity to allow reuse
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
p.largePool.Put(buf)
|
||||
}
|
||||
|
||||
// GetMedium returns a medium buffer (32KB) from the pool
|
||||
// The returned buffer should be returned via PutMedium when done
|
||||
func (p *AdaptiveBufferPool) GetMedium() *[]byte {
|
||||
return p.mediumPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutMedium returns a medium buffer to the pool for reuse
|
||||
func (p *AdaptiveBufferPool) PutMedium(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
// Reset to full capacity to allow reuse
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
p.mediumPool.Put(buf)
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HeaderPool manages a pool of http.Header objects for reuse.
|
||||
type HeaderPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
// NewHeaderPool creates a new header pool
|
||||
func NewHeaderPool() *HeaderPool {
|
||||
return &HeaderPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make(http.Header, 12)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header from the pool.
|
||||
func (p *HeaderPool) Get() http.Header {
|
||||
h := p.pool.Get().(http.Header)
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Put returns a header to the pool.
|
||||
func (p *HeaderPool) Put(h http.Header) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
p.pool.Put(h)
|
||||
}
|
||||
|
||||
// Clone creates a copy of src into dst, reusing dst's underlying storage
|
||||
// This is more efficient than creating a new header from scratch
|
||||
func (p *HeaderPool) Clone(dst, src http.Header) {
|
||||
// Clear dst first
|
||||
for k := range dst {
|
||||
delete(dst, k)
|
||||
}
|
||||
|
||||
// Copy all headers from src to dst
|
||||
for k, vv := range src {
|
||||
// Allocate new slice with exact capacity to avoid over-allocation
|
||||
dst[k] = make([]string, len(vv))
|
||||
copy(dst[k], vv)
|
||||
}
|
||||
}
|
||||
|
||||
// CloneWithExtra clones src into dst and adds/overwrites extra headers
|
||||
// This is optimized for the common pattern of cloning + adding Host header
|
||||
func (p *HeaderPool) CloneWithExtra(dst, src http.Header, extraKey, extraValue string) {
|
||||
// Clear dst first
|
||||
for k := range dst {
|
||||
delete(dst, k)
|
||||
}
|
||||
|
||||
// Copy all headers from src to dst
|
||||
for k, vv := range src {
|
||||
dst[k] = make([]string, len(vv))
|
||||
copy(dst[k], vv)
|
||||
}
|
||||
|
||||
// Set extra header (overwrite if exists)
|
||||
dst.Set(extraKey, extraValue)
|
||||
}
|
||||
|
||||
// globalHeaderPool is a package-level pool for convenience
|
||||
var globalHeaderPool = NewHeaderPool()
|
||||
|
||||
// GetHeader retrieves a header from the global pool
|
||||
func GetHeader() http.Header {
|
||||
return globalHeaderPool.Get()
|
||||
}
|
||||
|
||||
// PutHeader returns a header to the global pool
|
||||
func PutHeader(h http.Header) {
|
||||
globalHeaderPool.Put(h)
|
||||
}
|
||||
@@ -2,81 +2,23 @@ package protocol
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// AdaptivePoolManager dynamically adjusts buffer pool usage based on load
|
||||
// AdaptivePoolManager tracks active connections for load monitoring
|
||||
type AdaptivePoolManager struct {
|
||||
activeConnections atomic.Int64
|
||||
currentThreshold atomic.Int64
|
||||
highLoadConnectionThreshold int64
|
||||
midLoadConnectionThreshold int64
|
||||
midLoadThreshold int64
|
||||
highLoadThreshold int64
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
var globalAdaptiveManager = NewAdaptivePoolManager()
|
||||
|
||||
func NewAdaptivePoolManager() *AdaptivePoolManager {
|
||||
m := &AdaptivePoolManager{
|
||||
highLoadConnectionThreshold: 300,
|
||||
midLoadConnectionThreshold: 150,
|
||||
midLoadThreshold: int64(pool.SizeLarge),
|
||||
highLoadThreshold: int64(pool.SizeMedium),
|
||||
}
|
||||
|
||||
m.currentThreshold.Store(m.midLoadThreshold)
|
||||
go m.monitor()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) monitor() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
connections := m.activeConnections.Load()
|
||||
|
||||
if connections >= m.highLoadConnectionThreshold {
|
||||
m.currentThreshold.Store(m.highLoadThreshold)
|
||||
} else if connections < m.midLoadConnectionThreshold {
|
||||
m.currentThreshold.Store(m.midLoadThreshold)
|
||||
}
|
||||
// Hysteresis zone (150-300): maintain current threshold
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) GetThreshold() int {
|
||||
return int(m.currentThreshold.Load())
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) RegisterConnection() {
|
||||
m.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) UnregisterConnection() {
|
||||
m.activeConnections.Add(-1)
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) GetActiveConnections() int64 {
|
||||
return m.activeConnections.Load()
|
||||
}
|
||||
|
||||
func GetAdaptiveThreshold() int {
|
||||
return globalAdaptiveManager.GetThreshold()
|
||||
}
|
||||
var globalAdaptiveManager = &AdaptivePoolManager{}
|
||||
|
||||
func RegisterConnection() {
|
||||
globalAdaptiveManager.RegisterConnection()
|
||||
globalAdaptiveManager.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func UnregisterConnection() {
|
||||
globalAdaptiveManager.UnregisterConnection()
|
||||
globalAdaptiveManager.activeConnections.Add(-1)
|
||||
}
|
||||
|
||||
func GetActiveConnections() int64 {
|
||||
return globalAdaptiveManager.GetActiveConnections()
|
||||
return globalAdaptiveManager.activeConnections.Load()
|
||||
}
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// DataHeader represents a binary-encoded data header for data plane
|
||||
// All data transmission uses pure binary encoding for performance
|
||||
type DataHeader struct {
|
||||
Type DataType
|
||||
IsLast bool
|
||||
StreamID string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
// DataType represents the type of data frame
|
||||
type DataType uint8
|
||||
|
||||
const (
|
||||
DataTypeData DataType = 0x00 // 000
|
||||
DataTypeResponse DataType = 0x01 // 001
|
||||
DataTypeClose DataType = 0x02 // 010
|
||||
DataTypeHTTPRequest DataType = 0x03 // 011
|
||||
DataTypeHTTPResponse DataType = 0x04 // 100
|
||||
DataTypeHTTPHead DataType = 0x05 // 101 - streaming headers (shared)
|
||||
DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
|
||||
|
||||
// Reuse the same type codes for request streaming to stay within 3 bits.
|
||||
DataTypeHTTPRequestHead DataType = DataTypeHTTPHead
|
||||
DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
|
||||
)
|
||||
|
||||
// String returns the string representation of DataType
|
||||
func (t DataType) String() string {
|
||||
switch t {
|
||||
case DataTypeData:
|
||||
return "data"
|
||||
case DataTypeResponse:
|
||||
return "response"
|
||||
case DataTypeClose:
|
||||
return "close"
|
||||
case DataTypeHTTPRequest:
|
||||
return "http_request"
|
||||
case DataTypeHTTPResponse:
|
||||
return "http_response"
|
||||
case DataTypeHTTPHead:
|
||||
return "http_head"
|
||||
case DataTypeHTTPBodyChunk:
|
||||
return "http_body_chunk"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// FromString converts a string to DataType
|
||||
func DataTypeFromString(s string) DataType {
|
||||
switch s {
|
||||
case "data":
|
||||
return DataTypeData
|
||||
case "response":
|
||||
return DataTypeResponse
|
||||
case "close":
|
||||
return DataTypeClose
|
||||
case "http_request":
|
||||
return DataTypeHTTPRequest
|
||||
case "http_response":
|
||||
return DataTypeHTTPResponse
|
||||
case "http_head":
|
||||
return DataTypeHTTPHead
|
||||
case "http_body_chunk":
|
||||
return DataTypeHTTPBodyChunk
|
||||
default:
|
||||
return DataTypeData
|
||||
}
|
||||
}
|
||||
|
||||
// Binary format:
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | Flags | StreamID Length | RequestID Len |
|
||||
// | 1 byte | 2 bytes | 2 bytes |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | StreamID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | RequestID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
//
|
||||
// Flags (8 bits):
|
||||
// - Bit 0-2: Type (3 bits)
|
||||
// - Bit 3: IsLast (1 bit)
|
||||
// - Bit 4-7: Reserved (4 bits)
|
||||
|
||||
const (
|
||||
binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
|
||||
)
|
||||
|
||||
// MarshalBinary encodes the header to binary format
|
||||
func (h *DataHeader) MarshalBinary() []byte {
|
||||
streamIDLen := len(h.StreamID)
|
||||
requestIDLen := len(h.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
buf := make([]byte, totalLen)
|
||||
|
||||
// Encode flags
|
||||
flags := uint8(h.Type) & 0x07 // Type uses bits 0-2
|
||||
if h.IsLast {
|
||||
flags |= 0x08 // IsLast uses bit 3
|
||||
}
|
||||
buf[0] = flags
|
||||
|
||||
// Encode lengths (big-endian)
|
||||
binary.BigEndian.PutUint16(buf[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(buf[3:5], uint16(requestIDLen))
|
||||
|
||||
// Encode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
copy(buf[offset:], h.StreamID)
|
||||
offset += streamIDLen
|
||||
|
||||
// Encode RequestID
|
||||
copy(buf[offset:], h.RequestID)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes the header from binary format
|
||||
func (h *DataHeader) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < binaryHeaderMinSize {
|
||||
return errors.New("invalid binary header: too short")
|
||||
}
|
||||
|
||||
// Decode flags
|
||||
flags := data[0]
|
||||
h.Type = DataType(flags & 0x07) // Bits 0-2
|
||||
h.IsLast = (flags & 0x08) != 0 // Bit 3
|
||||
|
||||
// Decode lengths
|
||||
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
requestIDLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
// Validate total length
|
||||
expectedLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
if len(data) < expectedLen {
|
||||
return errors.New("invalid binary header: length mismatch")
|
||||
}
|
||||
|
||||
// Decode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
h.StreamID = string(data[offset : offset+streamIDLen])
|
||||
offset += streamIDLen
|
||||
|
||||
// Decode RequestID
|
||||
h.RequestID = string(data[offset : offset+requestIDLen])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size returns the size of the binary-encoded header
|
||||
func (h *DataHeader) Size() int {
|
||||
return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
type FlowControlAction string
|
||||
|
||||
const (
|
||||
FlowControlPause FlowControlAction = "pause"
|
||||
FlowControlResume FlowControlAction = "resume"
|
||||
)
|
||||
|
||||
type FlowControlMessage struct {
|
||||
StreamID string `json:"stream_id"`
|
||||
Action FlowControlAction `json:"action"`
|
||||
}
|
||||
|
||||
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
|
||||
msg := FlowControlMessage{
|
||||
StreamID: streamID,
|
||||
Action: action,
|
||||
}
|
||||
payload, _ := json.Marshal(&msg)
|
||||
return NewFrame(FrameTypeFlowControl, payload)
|
||||
}
|
||||
|
||||
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
|
||||
var msg FlowControlMessage
|
||||
if err := json.Unmarshal(payload, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
@@ -18,14 +18,14 @@ const (
|
||||
type FrameType byte
|
||||
|
||||
const (
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeData FrameType = 0x05
|
||||
FrameTypeClose FrameType = 0x06
|
||||
FrameTypeError FrameType = 0x07
|
||||
FrameTypeFlowControl FrameType = 0x08
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeClose FrameType = 0x05
|
||||
FrameTypeError FrameType = 0x06
|
||||
FrameTypeDataConnect FrameType = 0x07
|
||||
FrameTypeDataConnectAck FrameType = 0x08
|
||||
)
|
||||
|
||||
// String returns the string representation of frame type
|
||||
@@ -39,14 +39,14 @@ func (t FrameType) String() string {
|
||||
return "Heartbeat"
|
||||
case FrameTypeHeartbeatAck:
|
||||
return "HeartbeatAck"
|
||||
case FrameTypeData:
|
||||
return "Data"
|
||||
case FrameTypeClose:
|
||||
return "Close"
|
||||
case FrameTypeError:
|
||||
return "Error"
|
||||
case FrameTypeFlowControl:
|
||||
return "FlowControl"
|
||||
case FrameTypeDataConnect:
|
||||
return "DataConnect"
|
||||
case FrameTypeDataConnectAck:
|
||||
return "DataConnectAck"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown(%d)", t)
|
||||
}
|
||||
@@ -56,6 +56,9 @@ type Frame struct {
|
||||
Type FrameType
|
||||
Payload []byte
|
||||
poolBuffer *[]byte
|
||||
// queuedBytes is set by FrameWriter when the frame is enqueued.
|
||||
// It allows the writer to decrement backlog counters exactly once.
|
||||
queuedBytes int64
|
||||
}
|
||||
|
||||
func WriteFrame(w io.Writer, frame *Frame) error {
|
||||
@@ -130,6 +133,8 @@ func (f *Frame) Release() {
|
||||
f.poolBuffer = nil
|
||||
f.Payload = nil
|
||||
}
|
||||
// Reset queued marker to avoid carrying over stale state if the frame is reused.
|
||||
f.queuedBytes = 0
|
||||
}
|
||||
|
||||
// NewFrame creates a new frame
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (optimized)
|
||||
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
|
||||
return msgpack.Marshal(req)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequest decodes HTTPRequest with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var req HTTPRequest
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPRequestHead encodes HTTP request headers for streaming
|
||||
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequestHead decodes HTTP request headers for streaming
|
||||
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPRequestHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
|
||||
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
|
||||
return msgpack.Marshal(resp)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponse decodes HTTPResponse with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var resp HTTPResponse
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponseHead encodes HTTP response headers for streaming
|
||||
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponseHead decodes HTTP response headers for streaming
|
||||
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPResponseHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package protocol
|
||||
|
||||
// MessageType defines the type of tunnel message
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
// TypeRegister is sent when a client connects and gets a subdomain assigned
|
||||
TypeRegister MessageType = "register"
|
||||
// TypeRequest is sent from server to client when an HTTP request arrives
|
||||
TypeRequest MessageType = "request"
|
||||
// TypeResponse is sent from client to server with the HTTP response
|
||||
TypeResponse MessageType = "response"
|
||||
// TypeHeartbeat is sent periodically to keep the connection alive
|
||||
TypeHeartbeat MessageType = "heartbeat"
|
||||
// TypeError is sent when an error occurs
|
||||
TypeError MessageType = "error"
|
||||
)
|
||||
|
||||
// Message represents a tunnel protocol message
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Subdomain string `json:"subdomain,omitempty"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequest represents an HTTP request to be forwarded
|
||||
type HTTPRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequestHead represents HTTP request headers for streaming (no body)
|
||||
type HTTPRequestHead struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// HTTPResponse represents an HTTP response from the local service
|
||||
type HTTPResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPResponseHead represents HTTP response headers for streaming (no body)
|
||||
type HTTPResponseHead struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// RegisterData contains information sent when a tunnel is registered
|
||||
type RegisterData struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
URL string `json:"url"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ErrorData contains error information
|
||||
type ErrorData struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
@@ -2,12 +2,23 @@ package protocol
|
||||
|
||||
import json "github.com/goccy/go-json"
|
||||
|
||||
// PoolCapabilities advertises client connection pool capabilities
|
||||
type PoolCapabilities struct {
|
||||
MaxDataConns int `json:"max_data_conns"` // Maximum data connections client supports
|
||||
Version int `json:"version"` // Protocol version for pool features
|
||||
}
|
||||
|
||||
// RegisterRequest is sent by client to register a tunnel
|
||||
type RegisterRequest struct {
|
||||
Token string `json:"token"` // Authentication token
|
||||
CustomSubdomain string `json:"custom_subdomain"` // Optional custom subdomain
|
||||
TunnelType TunnelType `json:"tunnel_type"` // http, tcp, udp
|
||||
LocalPort int `json:"local_port"` // Local port to forward to
|
||||
|
||||
// Connection pool fields (optional, for multi-connection support)
|
||||
ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy
|
||||
TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join
|
||||
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
|
||||
}
|
||||
|
||||
// RegisterResponse is sent by server after successful registration
|
||||
@@ -16,6 +27,25 @@ type RegisterResponse struct {
|
||||
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
|
||||
URL string `json:"url"` // Full tunnel URL
|
||||
Message string `json:"message"` // Success message
|
||||
|
||||
// Connection pool fields (optional, for multi-connection support)
|
||||
TunnelID string `json:"tunnel_id,omitempty"` // Unique tunnel identifier
|
||||
SupportsDataConn bool `json:"supports_data_conn,omitempty"` // Server supports multi-connection
|
||||
RecommendedConns int `json:"recommended_conns,omitempty"` // Suggested data connection count
|
||||
}
|
||||
|
||||
// DataConnectRequest is sent by data connections to join a tunnel
|
||||
type DataConnectRequest struct {
|
||||
TunnelID string `json:"tunnel_id"` // Tunnel to join
|
||||
Token string `json:"token"` // Same auth token as primary
|
||||
ConnectionID string `json:"connection_id"` // Unique connection identifier
|
||||
}
|
||||
|
||||
// DataConnectResponse acknowledges data connection
|
||||
type DataConnectResponse struct {
|
||||
Accepted bool `json:"accepted"` // Whether connection was accepted
|
||||
ConnectionID string `json:"connection_id"` // Echoed connection ID
|
||||
Message string `json:"message,omitempty"` // Optional message
|
||||
}
|
||||
|
||||
// ErrorMessage represents an error
|
||||
@@ -24,9 +54,6 @@ type ErrorMessage struct {
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// Note: DataHeader is now defined in binary_header.go as a pure binary structure
|
||||
// TCPData has been removed - use DataHeader + raw bytes directly
|
||||
|
||||
// Marshal helpers for control plane messages (JSON encoding)
|
||||
func MarshalJSON(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// encodeDataPayload encodes a data header and payload into a frame payload.
|
||||
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
|
||||
payload := make([]byte, totalLen)
|
||||
|
||||
flags := uint8(header.Type) & 0x07
|
||||
if header.IsLast {
|
||||
flags |= 0x08
|
||||
}
|
||||
payload[0] = flags
|
||||
|
||||
binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
|
||||
|
||||
offset := binaryHeaderMinSize
|
||||
copy(payload[offset:], header.StreamID)
|
||||
offset += streamIDLen
|
||||
copy(payload[offset:], header.RequestID)
|
||||
offset += requestIDLen
|
||||
copy(payload[offset:], data)
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
|
||||
// Returns payload slice and pool buffer pointer (may be nil).
|
||||
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
|
||||
|
||||
dynamicThreshold := GetAdaptiveThreshold()
|
||||
|
||||
if totalLen < dynamicThreshold {
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
if totalLen > pool.SizeLarge {
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
poolBuffer = pool.GetBuffer(totalLen)
|
||||
payload = (*poolBuffer)[:totalLen]
|
||||
|
||||
flags := uint8(header.Type) & 0x07
|
||||
if header.IsLast {
|
||||
flags |= 0x08
|
||||
}
|
||||
payload[0] = flags
|
||||
|
||||
binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
|
||||
|
||||
offset := binaryHeaderMinSize
|
||||
copy(payload[offset:], header.StreamID)
|
||||
offset += streamIDLen
|
||||
copy(payload[offset:], header.RequestID)
|
||||
offset += requestIDLen
|
||||
copy(payload[offset:], data)
|
||||
|
||||
return payload, poolBuffer, nil
|
||||
}
|
||||
|
||||
// DecodeDataPayload decodes a frame payload into header and data.
|
||||
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
|
||||
if len(payload) < binaryHeaderMinSize {
|
||||
return DataHeader{}, nil, errors.New("invalid payload: too short")
|
||||
}
|
||||
|
||||
var header DataHeader
|
||||
if err := header.UnmarshalBinary(payload); err != nil {
|
||||
return DataHeader{}, nil, err
|
||||
}
|
||||
|
||||
headerSize := header.Size()
|
||||
if len(payload) < headerSize {
|
||||
return DataHeader{}, nil, errors.New("invalid payload: data missing")
|
||||
}
|
||||
|
||||
data := payload[headerSize:]
|
||||
return header, data, nil
|
||||
}
|
||||
@@ -4,20 +4,11 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SafeFrame wraps Frame with automatic resource cleanup
|
||||
type SafeFrame struct {
|
||||
*Frame
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewSafeFrame creates a SafeFrame that implements io.Closer
|
||||
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
|
||||
return &SafeFrame{
|
||||
Frame: NewFrame(frameType, payload),
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements io.Closer, ensures Release is called exactly once
|
||||
func (sf *SafeFrame) Close() error {
|
||||
sf.once.Do(func() {
|
||||
if sf.Frame != nil {
|
||||
@@ -27,14 +18,6 @@ func (sf *SafeFrame) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithFrame wraps an existing Frame with automatic cleanup
|
||||
func WithFrame(frame *Frame) *SafeFrame {
|
||||
return &SafeFrame{Frame: frame}
|
||||
}
|
||||
|
||||
// MustClose is a helper that calls Close and panics on error (for defer cleanup)
|
||||
func (sf *SafeFrame) MustClose() {
|
||||
if err := sf.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,18 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FrameWriter struct {
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
controlQueue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
|
||||
maxBatch int
|
||||
maxBatchWait time.Duration
|
||||
@@ -24,13 +26,20 @@ type FrameWriter struct {
|
||||
heartbeatControl chan struct{}
|
||||
|
||||
// Error handling
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
|
||||
// Adaptive flushing
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
|
||||
// Hooks
|
||||
preWriteHook func(*Frame) // Called right before a frame is written to conn
|
||||
|
||||
// Backlog tracking
|
||||
queuedFrames atomic.Int64
|
||||
queuedBytes atomic.Int64
|
||||
}
|
||||
|
||||
func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
@@ -41,8 +50,14 @@ func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
|
||||
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
|
||||
w := &FrameWriter{
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
controlQueue: make(chan *Frame, func() int {
|
||||
if queueSize < 256 {
|
||||
return queueSize
|
||||
}
|
||||
return 256
|
||||
}()), // control path needs small, fast buffer
|
||||
batch: make([]*Frame, 0, maxBatch),
|
||||
maxBatch: maxBatch,
|
||||
maxBatchWait: maxBatchWait,
|
||||
@@ -74,6 +89,22 @@ func (w *FrameWriter) writeLoop() {
|
||||
}()
|
||||
|
||||
for {
|
||||
// Always drain control queue first to prioritize control/heartbeat frames.
|
||||
select {
|
||||
case frame, ok := <-w.controlQueue:
|
||||
if !ok {
|
||||
w.mu.Lock()
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
w.flushFrameLocked(frame)
|
||||
w.mu.Unlock()
|
||||
continue
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case frame, ok := <-w.queue:
|
||||
if !ok {
|
||||
@@ -105,8 +136,7 @@ func (w *FrameWriter) writeLoop() {
|
||||
w.mu.Lock()
|
||||
if w.heartbeatCallback != nil {
|
||||
if frame := w.heartbeatCallback(); frame != nil {
|
||||
w.batch = append(w.batch, frame)
|
||||
w.flushBatchLocked()
|
||||
w.flushFrameLocked(frame)
|
||||
}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
@@ -139,22 +169,47 @@ func (w *FrameWriter) flushBatchLocked() {
|
||||
}
|
||||
|
||||
for _, frame := range w.batch {
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
frame.Release()
|
||||
w.flushFrameLocked(frame)
|
||||
}
|
||||
|
||||
w.batch = w.batch[:0]
|
||||
}
|
||||
|
||||
// flushFrameLocked writes a single frame immediately. Caller must hold w.mu.
|
||||
func (w *FrameWriter) flushFrameLocked(frame *Frame) {
|
||||
if frame == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if w.preWriteHook != nil {
|
||||
w.preWriteHook(frame)
|
||||
}
|
||||
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return w.WriteFrameWithCancel(frame, nil)
|
||||
}
|
||||
|
||||
// WriteFrameWithCancel writes a frame with an optional cancellation channel
|
||||
// If cancel is closed, the write will be aborted immediately
|
||||
func (w *FrameWriter) WriteFrameWithCancel(frame *Frame, cancel <-chan struct{}) error {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
@@ -165,10 +220,19 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
size := int64(len(frame.Payload) + FrameHeaderSize)
|
||||
w.queuedFrames.Add(1)
|
||||
w.queuedBytes.Add(size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, size)
|
||||
|
||||
// Try non-blocking first for best performance
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
@@ -176,6 +240,54 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue full - block with cancellation support
|
||||
if cancel != nil {
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-cancel:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("write cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// No cancel channel - block with timeout
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-time.After(30 * time.Second):
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("write queue full timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,8 +301,14 @@ func (w *FrameWriter) Close() error {
|
||||
w.mu.Unlock()
|
||||
|
||||
close(w.queue)
|
||||
close(w.controlQueue)
|
||||
|
||||
for frame := range w.queue {
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
for frame := range w.controlQueue {
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
@@ -264,3 +382,97 @@ func (w *FrameWriter) DisableAdaptiveFlush() {
|
||||
w.adaptiveFlush = false
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// WriteControl enqueues a control/prioritized frame to be written ahead of data frames.
|
||||
func (w *FrameWriter) WriteControl(frame *Frame) error {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
if w.writeErr != nil {
|
||||
return w.writeErr
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
size := int64(len(frame.Payload) + FrameHeaderSize)
|
||||
w.queuedFrames.Add(1)
|
||||
w.queuedBytes.Add(size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, size)
|
||||
|
||||
// Try non-blocking first
|
||||
select {
|
||||
case w.controlQueue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue full - wait with timeout
|
||||
select {
|
||||
case w.controlQueue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Control frames should have priority, shorter timeout
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("control queue full timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// SetPreWriteHook registers a callback invoked just before a frame is written to the underlying writer.
|
||||
func (w *FrameWriter) SetPreWriteHook(hook func(*Frame)) {
|
||||
w.mu.Lock()
|
||||
w.preWriteHook = hook
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// QueuedFrames returns the number of frames currently queued (data + control).
|
||||
func (w *FrameWriter) QueuedFrames() int64 {
|
||||
return w.queuedFrames.Load()
|
||||
}
|
||||
|
||||
// QueuedBytes returns the approximate number of bytes currently queued.
|
||||
func (w *FrameWriter) QueuedBytes() int64 {
|
||||
return w.queuedBytes.Load()
|
||||
}
|
||||
|
||||
// unmarkQueued decrements backlog counters for a frame once it is written or discarded.
|
||||
func (w *FrameWriter) unmarkQueued(frame *Frame) {
|
||||
if frame == nil {
|
||||
return
|
||||
}
|
||||
size := atomic.SwapInt64(&frame.queuedBytes, 0)
|
||||
if size <= 0 {
|
||||
return
|
||||
}
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
}
|
||||
|
||||
77
internal/shared/stats/format.go
Normal file
77
internal/shared/stats/format.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package stats
|
||||
|
||||
// FormatBytes formats bytes to human readable string
|
||||
func FormatBytes(bytes int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = KB * 1024
|
||||
GB = MB * 1024
|
||||
)
|
||||
|
||||
switch {
|
||||
case bytes >= GB:
|
||||
return formatFloat(float64(bytes)/float64(GB)) + " GB"
|
||||
case bytes >= MB:
|
||||
return formatFloat(float64(bytes)/float64(MB)) + " MB"
|
||||
case bytes >= KB:
|
||||
return formatFloat(float64(bytes)/float64(KB)) + " KB"
|
||||
default:
|
||||
return formatInt(bytes) + " B"
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSpeed formats speed (bytes per second) to human readable string
|
||||
func FormatSpeed(bytesPerSec int64) string {
|
||||
if bytesPerSec == 0 {
|
||||
return "0 B/s"
|
||||
}
|
||||
return FormatBytes(bytesPerSec) + "/s"
|
||||
}
|
||||
|
||||
func formatFloat(f float64) string {
|
||||
if f >= 100 {
|
||||
return formatInt(int64(f))
|
||||
} else if f >= 10 {
|
||||
return formatOneDecimal(f)
|
||||
}
|
||||
return formatTwoDecimal(f)
|
||||
}
|
||||
|
||||
func formatInt(i int64) string {
|
||||
return intToStr(i)
|
||||
}
|
||||
|
||||
func formatOneDecimal(f float64) string {
|
||||
i := int64(f * 10)
|
||||
whole := i / 10
|
||||
frac := i % 10
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func formatTwoDecimal(f float64) string {
|
||||
i := int64(f * 100)
|
||||
whole := i / 100
|
||||
frac := i % 100
|
||||
if frac < 10 {
|
||||
return intToStr(whole) + ".0" + intToStr(frac)
|
||||
}
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func intToStr(i int64) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
if i < 0 {
|
||||
return "-" + intToStr(-i)
|
||||
}
|
||||
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}
|
||||
184
internal/shared/stats/stats.go
Normal file
184
internal/shared/stats/stats.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TrafficStats tracks traffic statistics for a tunnel connection
|
||||
type TrafficStats struct {
|
||||
// Total bytes
|
||||
totalBytesIn int64
|
||||
totalBytesOut int64
|
||||
|
||||
// Request counts
|
||||
totalRequests int64
|
||||
activeConnections int64
|
||||
|
||||
// For speed calculation
|
||||
lastBytesIn int64
|
||||
lastBytesOut int64
|
||||
lastTime time.Time
|
||||
speedMu sync.Mutex
|
||||
|
||||
// Current speed (bytes per second)
|
||||
speedIn int64
|
||||
speedOut int64
|
||||
|
||||
// Start time
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// NewTrafficStats creates a new traffic stats tracker
|
||||
func NewTrafficStats() *TrafficStats {
|
||||
now := time.Now()
|
||||
return &TrafficStats{
|
||||
startTime: now,
|
||||
lastTime: now,
|
||||
}
|
||||
}
|
||||
|
||||
// AddBytesIn adds incoming bytes to the counter
|
||||
func (s *TrafficStats) AddBytesIn(n int64) {
|
||||
atomic.AddInt64(&s.totalBytesIn, n)
|
||||
}
|
||||
|
||||
// AddBytesOut adds outgoing bytes to the counter
|
||||
func (s *TrafficStats) AddBytesOut(n int64) {
|
||||
atomic.AddInt64(&s.totalBytesOut, n)
|
||||
}
|
||||
|
||||
// AddRequest increments the request counter
|
||||
func (s *TrafficStats) AddRequest() {
|
||||
atomic.AddInt64(&s.totalRequests, 1)
|
||||
}
|
||||
|
||||
func (s *TrafficStats) IncActiveConnections() {
|
||||
atomic.AddInt64(&s.activeConnections, 1)
|
||||
}
|
||||
|
||||
func (s *TrafficStats) DecActiveConnections() {
|
||||
v := atomic.AddInt64(&s.activeConnections, -1)
|
||||
if v < 0 {
|
||||
atomic.StoreInt64(&s.activeConnections, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTotalBytesIn returns total incoming bytes
|
||||
func (s *TrafficStats) GetTotalBytesIn() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesIn)
|
||||
}
|
||||
|
||||
// GetTotalBytesOut returns total outgoing bytes
|
||||
func (s *TrafficStats) GetTotalBytesOut() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesOut)
|
||||
}
|
||||
|
||||
// GetTotalRequests returns total request count
|
||||
func (s *TrafficStats) GetTotalRequests() int64 {
|
||||
return atomic.LoadInt64(&s.totalRequests)
|
||||
}
|
||||
|
||||
func (s *TrafficStats) GetActiveConnections() int64 {
|
||||
return atomic.LoadInt64(&s.activeConnections)
|
||||
}
|
||||
|
||||
// GetTotalBytes returns total bytes (in + out)
|
||||
func (s *TrafficStats) GetTotalBytes() int64 {
|
||||
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
|
||||
}
|
||||
|
||||
// UpdateSpeed calculates current transfer speed
|
||||
// Should be called periodically (e.g., every second)
|
||||
func (s *TrafficStats) UpdateSpeed() {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(s.lastTime).Seconds()
|
||||
|
||||
// Require minimum interval of 100ms to avoid division issues
|
||||
if elapsed < 0.1 {
|
||||
return
|
||||
}
|
||||
|
||||
currentIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
currentOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
|
||||
deltaIn := currentIn - s.lastBytesIn
|
||||
deltaOut := currentOut - s.lastBytesOut
|
||||
|
||||
// Calculate instantaneous speed
|
||||
if deltaIn > 0 {
|
||||
s.speedIn = int64(float64(deltaIn) / elapsed)
|
||||
} else {
|
||||
// No new bytes - set speed to 0 immediately
|
||||
s.speedIn = 0
|
||||
}
|
||||
|
||||
if deltaOut > 0 {
|
||||
s.speedOut = int64(float64(deltaOut) / elapsed)
|
||||
} else {
|
||||
// No new bytes - set speed to 0 immediately
|
||||
s.speedOut = 0
|
||||
}
|
||||
|
||||
s.lastBytesIn = currentIn
|
||||
s.lastBytesOut = currentOut
|
||||
s.lastTime = now
|
||||
}
|
||||
|
||||
// GetSpeedIn returns current incoming speed in bytes per second
|
||||
func (s *TrafficStats) GetSpeedIn() int64 {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
return s.speedIn
|
||||
}
|
||||
|
||||
// GetSpeedOut returns current outgoing speed in bytes per second
|
||||
func (s *TrafficStats) GetSpeedOut() int64 {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
return s.speedOut
|
||||
}
|
||||
|
||||
// GetUptime returns how long the connection has been active
|
||||
func (s *TrafficStats) GetUptime() time.Duration {
|
||||
return time.Since(s.startTime)
|
||||
}
|
||||
|
||||
// Snapshot returns a snapshot of all stats
|
||||
type Snapshot struct {
|
||||
TotalBytesIn int64
|
||||
TotalBytesOut int64
|
||||
TotalBytes int64
|
||||
TotalRequests int64
|
||||
ActiveConnections int64
|
||||
SpeedIn int64 // bytes per second
|
||||
SpeedOut int64 // bytes per second
|
||||
Uptime time.Duration
|
||||
}
|
||||
|
||||
// GetSnapshot returns a snapshot of current stats
|
||||
func (s *TrafficStats) GetSnapshot() Snapshot {
|
||||
s.speedMu.Lock()
|
||||
speedIn := s.speedIn
|
||||
speedOut := s.speedOut
|
||||
s.speedMu.Unlock()
|
||||
|
||||
totalIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
totalOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
active := atomic.LoadInt64(&s.activeConnections)
|
||||
|
||||
return Snapshot{
|
||||
TotalBytesIn: totalIn,
|
||||
TotalBytesOut: totalOut,
|
||||
TotalBytes: totalIn + totalOut,
|
||||
TotalRequests: atomic.LoadInt64(&s.totalRequests),
|
||||
ActiveConnections: active,
|
||||
SpeedIn: speedIn,
|
||||
SpeedOut: speedOut,
|
||||
Uptime: time.Since(s.startTime),
|
||||
}
|
||||
}
|
||||
117
internal/shared/ui/config.go
Normal file
117
internal/shared/ui/config.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// RenderConfigInit renders config initialization UI
|
||||
func RenderConfigInit() string {
|
||||
title := "Drip Configuration Setup"
|
||||
box := boxStyle.Width(50)
|
||||
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
|
||||
}
|
||||
|
||||
// RenderConfigShow renders the config display
|
||||
func RenderConfigShow(server, token string, tokenHidden bool, tlsEnabled bool, configPath string) string {
|
||||
lines := []string{
|
||||
KeyValue("Server", server),
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
if tokenHidden {
|
||||
if len(token) > 10 {
|
||||
displayToken := token[:3] + "***" + token[len(token)-3:]
|
||||
lines = append(lines, KeyValue("Token", Muted(displayToken+" (hidden)")))
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", Muted(token[:3]+"*** (hidden)")))
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", token))
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", Muted("(not set)")))
|
||||
}
|
||||
|
||||
tlsStatus := "enabled"
|
||||
if !tlsEnabled {
|
||||
tlsStatus = "disabled"
|
||||
}
|
||||
lines = append(lines, KeyValue("TLS", tlsStatus))
|
||||
lines = append(lines, KeyValue("Config", Muted(configPath)))
|
||||
|
||||
return Info("Current Configuration", lines...)
|
||||
}
|
||||
|
||||
// RenderConfigSaved renders config saved message
|
||||
func RenderConfigSaved(configPath string) string {
|
||||
return SuccessBox(
|
||||
"Configuration Saved",
|
||||
Muted("Config saved to: ")+configPath,
|
||||
"",
|
||||
Muted("You can now use 'drip' without --server and --token flags"),
|
||||
)
|
||||
}
|
||||
|
||||
// RenderConfigUpdated renders config updated message
|
||||
func RenderConfigUpdated(updates []string) string {
|
||||
lines := make([]string, len(updates)+1)
|
||||
for i, update := range updates {
|
||||
lines[i] = Success(update)
|
||||
}
|
||||
lines[len(updates)] = ""
|
||||
lines = append(lines, Muted("Configuration has been updated"))
|
||||
return SuccessBox("Configuration Updated", lines...)
|
||||
}
|
||||
|
||||
// RenderConfigDeleted renders config deleted message
|
||||
func RenderConfigDeleted() string {
|
||||
return SuccessBox("Configuration Deleted", Muted("Configuration file has been removed"))
|
||||
}
|
||||
|
||||
// RenderConfigValidation renders config validation results
|
||||
func RenderConfigValidation(serverValid bool, serverMsg string, tokenSet bool, tokenMsg string, tlsEnabled bool) string {
|
||||
lines := []string{}
|
||||
|
||||
if serverValid {
|
||||
lines = append(lines, Success(serverMsg))
|
||||
} else {
|
||||
lines = append(lines, Error(serverMsg))
|
||||
}
|
||||
|
||||
if tokenSet {
|
||||
lines = append(lines, Success(tokenMsg))
|
||||
} else {
|
||||
lines = append(lines, Warning(tokenMsg))
|
||||
}
|
||||
|
||||
if tlsEnabled {
|
||||
lines = append(lines, Success("TLS is enabled"))
|
||||
} else {
|
||||
lines = append(lines, Warning("TLS is disabled (not recommended for production)"))
|
||||
}
|
||||
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, Muted("Configuration validation complete"))
|
||||
|
||||
if serverValid && tokenSet && tlsEnabled {
|
||||
return SuccessBox("Configuration Valid", lines...)
|
||||
}
|
||||
return WarningBox("Configuration Validation", lines...)
|
||||
}
|
||||
|
||||
// RenderDaemonStarted renders daemon started message
|
||||
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
|
||||
lines := []string{
|
||||
KeyValue("Type", Highlight(tunnelType)),
|
||||
KeyValue("Port", fmt.Sprintf("%d", port)),
|
||||
KeyValue("PID", fmt.Sprintf("%d", pid)),
|
||||
"",
|
||||
Muted("Commands:"),
|
||||
Cyan(" drip list") + Muted(" Check tunnel status"),
|
||||
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
|
||||
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
|
||||
"",
|
||||
Muted("Logs: ") + mutedStyle.Render(logPath),
|
||||
}
|
||||
return SuccessBox("Tunnel Started in Background", lines...)
|
||||
}
|
||||
184
internal/shared/ui/styles.go
Normal file
184
internal/shared/ui/styles.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
// Colors inspired by Vercel CLI
|
||||
successColor = lipgloss.Color("#0070F3")
|
||||
warningColor = lipgloss.Color("#F5A623")
|
||||
errorColor = lipgloss.Color("#E00")
|
||||
mutedColor = lipgloss.Color("#888")
|
||||
highlightColor = lipgloss.Color("#0070F3")
|
||||
cyanColor = lipgloss.Color("#50E3C2")
|
||||
|
||||
// Box styles - Vercel-like clean box
|
||||
boxStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
Padding(1, 2).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
successBoxStyle = boxStyle.BorderForeground(successColor)
|
||||
|
||||
warningBoxStyle = boxStyle.BorderForeground(warningColor)
|
||||
|
||||
errorBoxStyle = boxStyle.BorderForeground(errorColor)
|
||||
|
||||
// Text styles
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
subtitleStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor)
|
||||
|
||||
successStyle = lipgloss.NewStyle().
|
||||
Foreground(successColor).
|
||||
Bold(true)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(errorColor).
|
||||
Bold(true)
|
||||
|
||||
warningStyle = lipgloss.NewStyle().
|
||||
Foreground(warningColor).
|
||||
Bold(true)
|
||||
|
||||
mutedStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor)
|
||||
|
||||
highlightStyle = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Bold(true)
|
||||
|
||||
cyanStyle = lipgloss.NewStyle().
|
||||
Foreground(cyanColor)
|
||||
|
||||
urlStyle = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Underline(true).
|
||||
Bold(true)
|
||||
|
||||
labelStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Width(12)
|
||||
|
||||
valueStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
// Table styles (padding handled manually for consistent Windows output)
|
||||
tableHeaderStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// Success returns a styled success message
|
||||
func Success(text string) string {
|
||||
return successStyle.Render("✓ " + text)
|
||||
}
|
||||
|
||||
// Error returns a styled error message
|
||||
func Error(text string) string {
|
||||
return errorStyle.Render("✗ " + text)
|
||||
}
|
||||
|
||||
// Warning returns a styled warning message
|
||||
func Warning(text string) string {
|
||||
return warningStyle.Render("⚠ " + text)
|
||||
}
|
||||
|
||||
// Muted returns a styled muted text
|
||||
func Muted(text string) string {
|
||||
return mutedStyle.Render(text)
|
||||
}
|
||||
|
||||
// Highlight returns a styled highlighted text
|
||||
func Highlight(text string) string {
|
||||
return highlightStyle.Render(text)
|
||||
}
|
||||
|
||||
// Cyan returns a styled cyan text
|
||||
func Cyan(text string) string {
|
||||
return cyanStyle.Render(text)
|
||||
}
|
||||
|
||||
// URL returns a styled URL
|
||||
func URL(text string) string {
|
||||
return urlStyle.Render(text)
|
||||
}
|
||||
|
||||
// Title returns a styled title
|
||||
func Title(text string) string {
|
||||
return titleStyle.Render(text)
|
||||
}
|
||||
|
||||
// Subtitle returns a styled subtitle
|
||||
func Subtitle(text string) string {
|
||||
return subtitleStyle.Render(text)
|
||||
}
|
||||
|
||||
// KeyValue returns a styled key-value pair
|
||||
func KeyValue(key, value string) string {
|
||||
return labelStyle.Render(key+":") + " " + valueStyle.Render(value)
|
||||
}
|
||||
|
||||
// Info renders an info box (Vercel-style)
|
||||
func Info(title string, lines ...string) string {
|
||||
content := titleStyle.Render(title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return boxStyle.Render(content)
|
||||
}
|
||||
|
||||
// SuccessBox renders a success box
|
||||
func SuccessBox(title string, lines ...string) string {
|
||||
content := successStyle.Render("✓ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return successBoxStyle.Render(content)
|
||||
}
|
||||
|
||||
// WarningBox renders a warning box
|
||||
func WarningBox(title string, lines ...string) string {
|
||||
content := warningStyle.Render("⚠ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return warningBoxStyle.Render(content)
|
||||
}
|
||||
|
||||
// ErrorBox renders an error box
|
||||
func ErrorBox(title string, lines ...string) string {
|
||||
content := errorStyle.Render("✗ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return errorBoxStyle.Render(content)
|
||||
}
|
||||
145
internal/shared/ui/table.go
Normal file
145
internal/shared/ui/table.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Table represents a simple table for CLI output
|
||||
type Table struct {
|
||||
headers []string
|
||||
rows [][]string
|
||||
title string
|
||||
}
|
||||
|
||||
// NewTable creates a new table
|
||||
func NewTable(headers []string) *Table {
|
||||
return &Table{
|
||||
headers: headers,
|
||||
rows: [][]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// WithTitle sets the table title
|
||||
func (t *Table) WithTitle(title string) *Table {
|
||||
t.title = title
|
||||
return t
|
||||
}
|
||||
|
||||
// AddRow adds a row to the table
|
||||
func (t *Table) AddRow(row []string) *Table {
|
||||
t.rows = append(t.rows, row)
|
||||
return t
|
||||
}
|
||||
|
||||
// Render renders the table (Vercel-style)
|
||||
func (t *Table) Render() string {
|
||||
if len(t.rows) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate column widths
|
||||
colWidths := make([]int, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
colWidths[i] = lipgloss.Width(header)
|
||||
}
|
||||
for _, row := range t.rows {
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
width := lipgloss.Width(cell)
|
||||
if width > colWidths[i] {
|
||||
colWidths[i] = width
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Title
|
||||
if t.title != "" {
|
||||
output.WriteString("\n")
|
||||
output.WriteString(titleStyle.Render(t.title))
|
||||
output.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Header
|
||||
headerParts := make([]string, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
styled := tableHeaderStyle.Render(header)
|
||||
headerParts[i] = padRight(styled, colWidths[i])
|
||||
}
|
||||
output.WriteString(strings.Join(headerParts, " "))
|
||||
output.WriteString("\n")
|
||||
|
||||
// Separator line
|
||||
separatorChar := "─"
|
||||
if runtime.GOOS == "windows" {
|
||||
separatorChar = "-"
|
||||
}
|
||||
separatorParts := make([]string, len(t.headers))
|
||||
for i := range t.headers {
|
||||
separatorParts[i] = mutedStyle.Render(strings.Repeat(separatorChar, colWidths[i]))
|
||||
}
|
||||
output.WriteString(strings.Join(separatorParts, " "))
|
||||
output.WriteString("\n")
|
||||
|
||||
// Rows
|
||||
for _, row := range t.rows {
|
||||
rowParts := make([]string, len(t.headers))
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
rowParts[i] = padRight(cell, colWidths[i])
|
||||
}
|
||||
}
|
||||
output.WriteString(strings.Join(rowParts, " "))
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
output.WriteString("\n")
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// padRight pads
|
||||
func padRight(text string, targetWidth int) string {
|
||||
visibleWidth := lipgloss.Width(text)
|
||||
if visibleWidth >= targetWidth {
|
||||
return text
|
||||
}
|
||||
padding := strings.Repeat(" ", targetWidth-visibleWidth)
|
||||
return text + padding
|
||||
}
|
||||
|
||||
// Print prints the table
|
||||
func (t *Table) Print() {
|
||||
fmt.Print(t.Render())
|
||||
}
|
||||
|
||||
// RenderList renders a simple list with bullet points
|
||||
func RenderList(items []string) string {
|
||||
bullet := "•"
|
||||
if runtime.GOOS == "windows" {
|
||||
bullet = "*"
|
||||
}
|
||||
var output strings.Builder
|
||||
for _, item := range items {
|
||||
output.WriteString(mutedStyle.Render(" " + bullet + " "))
|
||||
output.WriteString(item)
|
||||
output.WriteString("\n")
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// RenderNumberedList renders a numbered list
|
||||
func RenderNumberedList(items []string) string {
|
||||
var output strings.Builder
|
||||
for i, item := range items {
|
||||
output.WriteString(mutedStyle.Render(fmt.Sprintf(" %d. ", i+1)))
|
||||
output.WriteString(item)
|
||||
output.WriteString("\n")
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
251
internal/shared/ui/tunnel.go
Normal file
251
internal/shared/ui/tunnel.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const (
|
||||
tunnelCardWidth = 76
|
||||
statsColumnWidth = 32
|
||||
)
|
||||
|
||||
var (
|
||||
latencyFastColor = lipgloss.Color("#22c55e") // green
|
||||
latencyYellowColor = lipgloss.Color("#eab308") // yellow
|
||||
latencyOrangeColor = lipgloss.Color("#f97316") // orange
|
||||
latencyRedColor = lipgloss.Color("#ef4444") // red
|
||||
)
|
||||
|
||||
// TunnelStatus represents the status of a tunnel
|
||||
type TunnelStatus struct {
|
||||
Type string // "http", "https", "tcp"
|
||||
URL string // Public URL
|
||||
LocalAddr string // Local address
|
||||
Latency time.Duration // Current latency
|
||||
BytesIn int64 // Bytes received
|
||||
BytesOut int64 // Bytes sent
|
||||
SpeedIn float64 // Download speed
|
||||
SpeedOut float64 // Upload speed
|
||||
TotalRequest int64 // Total requests
|
||||
}
|
||||
|
||||
// RenderTunnelConnected renders the tunnel connection card
|
||||
func RenderTunnelConnected(status *TunnelStatus) string {
|
||||
icon, typeStr, accent := tunnelVisuals(status.Type)
|
||||
|
||||
card := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accent).
|
||||
Padding(1, 2).
|
||||
Width(tunnelCardWidth)
|
||||
|
||||
typeBadge := lipgloss.NewStyle().
|
||||
Background(accent).
|
||||
Foreground(lipgloss.Color("#f8fafc")).
|
||||
Bold(true).
|
||||
Padding(0, 1).
|
||||
Render(strings.ToUpper(typeStr) + " TUNNEL")
|
||||
|
||||
headline := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render(icon),
|
||||
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Tunnel Connected"),
|
||||
lipgloss.NewStyle().MarginLeft(2).Render(typeBadge),
|
||||
)
|
||||
|
||||
urlLine := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
urlStyle.Foreground(accent).Render(status.URL),
|
||||
lipgloss.NewStyle().MarginLeft(1).Foreground(mutedColor).Render("(forwarded address)"),
|
||||
)
|
||||
|
||||
forwardLine := lipgloss.NewStyle().
|
||||
MarginLeft(2).
|
||||
Render(Muted("⇢ ") + valueStyle.Render(status.LocalAddr))
|
||||
|
||||
hint := lipgloss.NewStyle().
|
||||
Foreground(latencyOrangeColor).
|
||||
Render("Ctrl+C to stop • reconnects automatically")
|
||||
|
||||
content := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
headline,
|
||||
"",
|
||||
urlLine,
|
||||
forwardLine,
|
||||
"",
|
||||
hint,
|
||||
)
|
||||
|
||||
return "\n" + card.Render(content) + "\n"
|
||||
}
|
||||
|
||||
// RenderTunnelStats renders real-time tunnel statistics in a card
|
||||
func RenderTunnelStats(status *TunnelStatus) string {
|
||||
latencyStr := formatLatency(status.Latency)
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", formatBytes(status.BytesIn), formatBytes(status.BytesOut))
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", formatSpeed(status.SpeedIn), formatSpeed(status.SpeedOut))
|
||||
requestsStr := fmt.Sprintf("%d", status.TotalRequest)
|
||||
|
||||
_, _, accent := tunnelVisuals(status.Type)
|
||||
|
||||
requestLabel := "Requests"
|
||||
if status.Type == "tcp" {
|
||||
requestLabel = "Connections"
|
||||
}
|
||||
|
||||
header := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render("◉"),
|
||||
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Live Metrics"),
|
||||
)
|
||||
|
||||
row1 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Latency", latencyStr, statsColumnWidth),
|
||||
statColumn(requestLabel, highlightStyle.Render(requestsStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
row2 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Traffic", Cyan(trafficStr), statsColumnWidth),
|
||||
statColumn("Speed", warningStyle.Render(speedStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
card := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accent).
|
||||
Padding(1, 2).
|
||||
Width(tunnelCardWidth)
|
||||
|
||||
body := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
header,
|
||||
"",
|
||||
row1,
|
||||
row2,
|
||||
)
|
||||
|
||||
return "\n" + card.Render(body) + "\n"
|
||||
}
|
||||
|
||||
// RenderConnecting renders the connecting message
|
||||
func RenderConnecting(serverAddr string, attempt int, maxAttempts int) string {
|
||||
if attempt == 0 {
|
||||
return Highlight("◌") + " Connecting to " + Muted(serverAddr) + "..."
|
||||
}
|
||||
return Warning(fmt.Sprintf("◌ Reconnecting to %s (attempt %d/%d)...", serverAddr, attempt, maxAttempts))
|
||||
}
|
||||
|
||||
// RenderConnectionFailed renders connection failure message
|
||||
func RenderConnectionFailed(err error) string {
|
||||
return Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
}
|
||||
|
||||
// RenderShuttingDown renders shutdown message
|
||||
func RenderShuttingDown() string {
|
||||
return Warning("⏹ Shutting down...")
|
||||
}
|
||||
|
||||
// RenderConnectionLost renders connection lost message
|
||||
func RenderConnectionLost() string {
|
||||
return Error("⚠ Connection lost!")
|
||||
}
|
||||
|
||||
// RenderRetrying renders retry message
|
||||
func RenderRetrying(interval time.Duration) string {
|
||||
return Muted(fmt.Sprintf(" Retrying in %v...", interval))
|
||||
}
|
||||
|
||||
// formatLatency formats latency with color
|
||||
func formatLatency(d time.Duration) string {
|
||||
if d == 0 {
|
||||
return mutedStyle.Render("measuring...")
|
||||
}
|
||||
|
||||
ms := d.Milliseconds()
|
||||
var style lipgloss.Style
|
||||
|
||||
switch {
|
||||
case ms < 50:
|
||||
style = lipgloss.NewStyle().Foreground(latencyFastColor)
|
||||
case ms < 150:
|
||||
style = lipgloss.NewStyle().Foreground(latencyYellowColor)
|
||||
case ms < 300:
|
||||
style = lipgloss.NewStyle().Foreground(latencyOrangeColor)
|
||||
default:
|
||||
style = lipgloss.NewStyle().Foreground(latencyRedColor)
|
||||
}
|
||||
|
||||
if ms == 0 {
|
||||
us := d.Microseconds()
|
||||
return style.Render(fmt.Sprintf("%dµs", us))
|
||||
}
|
||||
|
||||
return style.Render(fmt.Sprintf("%dms", ms))
|
||||
}
|
||||
|
||||
// formatBytes formats bytes to human readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// formatSpeed formats speed to human readable format
|
||||
func formatSpeed(bytesPerSec float64) string {
|
||||
const unit = 1024.0
|
||||
if bytesPerSec < unit {
|
||||
return fmt.Sprintf("%.0f B/s", bytesPerSec)
|
||||
}
|
||||
div, exp := unit, 0
|
||||
for n := bytesPerSec / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB/s", bytesPerSec/div, "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func statColumn(label, value string, width int) string {
|
||||
labelView := lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Render(strings.ToUpper(label))
|
||||
|
||||
block := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
labelView,
|
||||
lipgloss.NewStyle().MarginLeft(1).Render(value),
|
||||
)
|
||||
|
||||
if width <= 0 {
|
||||
return block
|
||||
}
|
||||
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Render(block)
|
||||
}
|
||||
|
||||
func tunnelVisuals(tunnelType string) (string, string, lipgloss.Color) {
|
||||
switch tunnelType {
|
||||
case "http":
|
||||
return "🚀", "HTTP", lipgloss.Color("#0070F3")
|
||||
case "https":
|
||||
return "🔒", "HTTPS", lipgloss.Color("#2D8CFF")
|
||||
case "tcp":
|
||||
return "🔌", "TCP", lipgloss.Color("#50E3C2")
|
||||
default:
|
||||
return "🌐", strings.ToUpper(tunnelType), lipgloss.Color("#0070F3")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user