feat(tunnel): switch to yamux stream proxying and connection pooling

- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
Gouryella
2025-12-13 18:03:44 +08:00
parent 3c93789266
commit 0c19c3300c
55 changed files with 3380 additions and 4849 deletions

View File

@@ -1,280 +0,0 @@
package hpack
import (
"bytes"
"errors"
"fmt"
"net/http"
"sync"
)
// Decoder decompresses HPACK-encoded headers
// Each connection MUST have its own decoder instance to maintain correct state
type Decoder struct {
	mu           sync.Mutex    // serializes Decode, SetMaxTableSize and Reset
	dynamicTable *DynamicTable // per-connection mutable table (RFC 7541 Section 2.3.2)
	staticTable  *StaticTable  // process-wide immutable table (RFC 7541 Appendix A)
	maxTableSize uint32        // size used when Reset rebuilds the dynamic table
}
// NewDecoder constructs an HPACK decoder whose dynamic table is capped at
// maxTableSize bytes; a zero size falls back to DefaultDynamicTableSize.
func NewDecoder(maxTableSize uint32) *Decoder {
	size := maxTableSize
	if size == 0 {
		size = DefaultDynamicTableSize
	}
	return &Decoder{
		dynamicTable: NewDynamicTable(size),
		staticTable:  GetStaticTable(),
		maxTableSize: size,
	}
}
// Decode decodes a complete HPACK header block into an http.Header.
// Empty input yields an empty, non-nil header set.
func (d *Decoder) Decode(data []byte) (http.Header, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	if len(data) == 0 {
		return http.Header{}, nil
	}
	headers := make(http.Header)
	reader := bytes.NewReader(data)
	for reader.Len() > 0 {
		// Peek at the first byte to pick the field representation, then
		// rewind so the chosen decoder consumes the full prefix itself.
		first, err := reader.ReadByte()
		if err != nil {
			return nil, fmt.Errorf("read header byte: %w", err)
		}
		if err := reader.UnreadByte(); err != nil {
			return nil, err
		}
		var name, value string
		switch {
		case first&0x80 != 0: // 1xxxxxxx: indexed header field
			name, value, err = d.decodeIndexedHeader(reader)
		case first&0x40 != 0: // 01xxxxxx: literal with incremental indexing
			name, value, err = d.decodeLiteralWithIndexing(reader)
		default: // 0000xxxx: literal without indexing
			name, value, err = d.decodeLiteralWithoutIndexing(reader)
		}
		if err != nil {
			return nil, err
		}
		headers.Add(name, value)
	}
	return headers, nil
}
// decodeIndexedHeader decodes a fully indexed field (RFC 7541 Section 6.1).
// Index 1..staticSize addresses the static table; larger indices address
// the dynamic table, whose entries start right after the static table.
func (d *Decoder) decodeIndexedHeader(buf *bytes.Reader) (string, string, error) {
	index, err := d.readInteger(buf, 7)
	if err != nil {
		return "", "", fmt.Errorf("read index: %w", err)
	}
	staticSize := uint32(d.staticTable.Size())
	switch {
	case index == 0:
		// Index zero is reserved and never valid on the wire.
		return "", "", errors.New("invalid index: 0")
	case index <= staticSize:
		return d.staticTable.Get(index - 1)
	default:
		return d.dynamicTable.Get(index - staticSize - 1)
	}
}
// decodeLiteralWithIndexing decodes a literal header field with incremental
// indexing (01xxxxxx, RFC 7541 Section 6.2.1). The decoded pair is added to
// the dynamic table so later frames may reference it by index.
func (d *Decoder) decodeLiteralWithIndexing(buf *bytes.Reader) (string, string, error) {
	name, value, err := d.decodeLiteralField(buf, 6)
	if err != nil {
		return "", "", err
	}
	d.dynamicTable.Add(name, value)
	return name, value, nil
}

// decodeLiteralWithoutIndexing decodes a literal header field without
// indexing (0000xxxx, RFC 7541 Section 6.2.2). The pair is deliberately
// NOT added to the dynamic table.
func (d *Decoder) decodeLiteralWithoutIndexing(buf *bytes.Reader) (string, string, error) {
	return d.decodeLiteralField(buf, 4)
}

// decodeLiteralField reads the wire layout shared by both literal forms
// (previously duplicated across the two decoders above): a name index in
// prefixBits bits — zero meaning the name follows as a literal string —
// then the value as a literal string.
func (d *Decoder) decodeLiteralField(buf *bytes.Reader, prefixBits int) (string, string, error) {
	nameIndex, err := d.readInteger(buf, prefixBits)
	if err != nil {
		return "", "", err
	}
	var name string
	if nameIndex == 0 {
		// Name is literal
		name, err = d.readString(buf)
		if err != nil {
			return "", "", fmt.Errorf("read name: %w", err)
		}
	} else {
		// Name is indexed; dynamic indices start after the static table.
		staticSize := uint32(d.staticTable.Size())
		if nameIndex <= staticSize {
			name, _, err = d.staticTable.Get(nameIndex - 1)
		} else {
			name, _, err = d.dynamicTable.Get(nameIndex - staticSize - 1)
		}
		if err != nil {
			return "", "", fmt.Errorf("get indexed name: %w", err)
		}
	}
	// Value is always literal
	value, err := d.readString(buf)
	if err != nil {
		return "", "", fmt.Errorf("read value: %w", err)
	}
	return name, value, nil
}
// readInteger reads an HPACK integer (RFC 7541 Section 5.1): a prefixBits-bit
// prefix, then optional continuation bytes carrying 7 bits each, least
// significant group first.
func (d *Decoder) readInteger(buf *bytes.Reader, prefixBits int) (uint32, error) {
	if prefixBits < 1 || prefixBits > 8 {
		return 0, fmt.Errorf("invalid prefix bits: %d", prefixBits)
	}
	b, err := buf.ReadByte()
	if err != nil {
		return 0, err
	}
	maxPrefix := uint32((1 << prefixBits) - 1)
	value := uint32(b) & maxPrefix
	if value < maxPrefix {
		return value, nil
	}
	// Multi-byte integer.
	for m := uint(0); ; m += 7 {
		// BUG FIX: the previous code checked m > 28 only when the
		// continuation bit was set, so a terminal byte at shift 28 could
		// wrap uint32 silently. Reject over-long encodings before reading,
		// and verify the accumulated value in 64-bit space before storing.
		if m > 28 {
			return 0, errors.New("integer overflow")
		}
		b, err := buf.ReadByte()
		if err != nil {
			return 0, err
		}
		total := uint64(value) + uint64(b&0x7f)<<m
		if total > 0xFFFFFFFF {
			return 0, errors.New("integer overflow")
		}
		value = uint32(total)
		if b&0x80 == 0 {
			return value, nil
		}
	}
}
// readString reads a length-prefixed HPACK string literal. The top bit of
// the first byte selects Huffman coding, which this decoder does not
// support and reports as an error (after consuming the data).
func (d *Decoder) readString(buf *bytes.Reader) (string, error) {
	first, err := buf.ReadByte()
	if err != nil {
		return "", err
	}
	if err := buf.UnreadByte(); err != nil {
		return "", err
	}
	isHuffman := first&0x80 != 0
	length, err := d.readInteger(buf, 7)
	if err != nil {
		return "", fmt.Errorf("read string length: %w", err)
	}
	if length == 0 {
		return "", nil
	}
	if length > uint32(buf.Len()) {
		return "", fmt.Errorf("string length %d exceeds buffer size %d", length, buf.Len())
	}
	raw := make([]byte, length)
	n, err := buf.Read(raw)
	if err != nil {
		return "", err
	}
	if n != int(length) {
		return "", fmt.Errorf("expected %d bytes, read %d", length, n)
	}
	if isHuffman {
		// TODO: Implement Huffman decoding if needed
		return "", errors.New("huffman decoding not implemented")
	}
	return string(raw), nil
}
// SetMaxTableSize resizes the dynamic table, evicting entries if needed.
func (d *Decoder) SetMaxTableSize(size uint32) {
	d.mu.Lock()
	d.maxTableSize = size
	d.dynamicTable.SetMaxSize(size)
	d.mu.Unlock()
}

// Reset discards all dynamic table state, restoring an empty table of the
// currently configured maximum size.
func (d *Decoder) Reset() {
	d.mu.Lock()
	d.dynamicTable = NewDynamicTable(d.maxTableSize)
	d.mu.Unlock()
}

View File

@@ -1,124 +0,0 @@
package hpack
import (
"fmt"
)
// DynamicTable implements the HPACK dynamic table (RFC 7541 Section 2.3.2).
// Entries are kept newest-first: index 0 is the most recently added field,
// and evictions remove from the tail (the oldest entry).
type DynamicTable struct {
	entries []HeaderField // newest first
	size    uint32        // current size in bytes (sum of entry sizes)
	maxSize uint32        // maximum size in bytes
}

// HeaderField represents a header name-value pair.
type HeaderField struct {
	Name  string
	Value string
}

// Size returns the size of this header field in bytes.
// RFC 7541 Section 4.1: size = len(name) + len(value) + 32.
func (h *HeaderField) Size() uint32 {
	return uint32(len(h.Name) + len(h.Value) + 32)
}

// NewDynamicTable creates a new dynamic table bounded by maxSize bytes.
func NewDynamicTable(maxSize uint32) *DynamicTable {
	return &DynamicTable{
		entries: make([]HeaderField, 0, 32),
		size:    0,
		maxSize: maxSize,
	}
}

// Add inserts a header field at index 0, evicting the oldest entries as
// needed to stay within maxSize. A field larger than the entire table
// clears the table and is not stored (RFC 7541 Section 4.4).
func (dt *DynamicTable) Add(name, value string) {
	field := HeaderField{Name: name, Value: value}
	fieldSize := field.Size()
	if fieldSize > dt.maxSize {
		dt.evictAll()
		return
	}
	// Evict entries if necessary to make room.
	for dt.size+fieldSize > dt.maxSize && len(dt.entries) > 0 {
		dt.evictOldest()
	}
	// Prepend in place: grow by one, shift right, write at index 0.
	// PERF FIX: unlike append([]HeaderField{field}, dt.entries...), this
	// reuses the existing backing array instead of allocating a fresh
	// slice on every insertion.
	dt.entries = append(dt.entries, HeaderField{})
	copy(dt.entries[1:], dt.entries)
	dt.entries[0] = field
	dt.size += fieldSize
}

// Get retrieves a header field by index (0-based).
// Index 0 is the most recently added entry.
func (dt *DynamicTable) Get(index uint32) (string, string, error) {
	if index >= uint32(len(dt.entries)) {
		return "", "", fmt.Errorf("index %d out of range (table size: %d)", index, len(dt.entries))
	}
	field := dt.entries[index]
	return field.Name, field.Value, nil
}

// FindExact searches for an exact match (name and value).
// Returns the index (0-based) and true if found.
func (dt *DynamicTable) FindExact(name, value string) (uint32, bool) {
	for i, field := range dt.entries {
		if field.Name == name && field.Value == value {
			return uint32(i), true
		}
	}
	return 0, false
}

// FindName searches for a name match.
// Returns the first (most recent) matching index and true if found.
func (dt *DynamicTable) FindName(name string) (uint32, bool) {
	for i, field := range dt.entries {
		if field.Name == name {
			return uint32(i), true
		}
	}
	return 0, false
}

// SetMaxSize updates the maximum table size, evicting the oldest entries
// until the table fits if the new size is smaller.
func (dt *DynamicTable) SetMaxSize(maxSize uint32) {
	dt.maxSize = maxSize
	for dt.size > dt.maxSize && len(dt.entries) > 0 {
		dt.evictOldest()
	}
}

// CurrentSize returns the current size of the table in bytes.
func (dt *DynamicTable) CurrentSize() uint32 {
	return dt.size
}

// evictOldest removes the oldest entry (last in the slice).
func (dt *DynamicTable) evictOldest() {
	if len(dt.entries) == 0 {
		return
	}
	lastIndex := len(dt.entries) - 1
	evicted := dt.entries[lastIndex]
	dt.entries = dt.entries[:lastIndex]
	dt.size -= evicted.Size()
}

// evictAll removes all entries but keeps the backing array for reuse.
func (dt *DynamicTable) evictAll() {
	dt.entries = dt.entries[:0]
	dt.size = 0
}

View File

@@ -1,200 +0,0 @@
package hpack
import (
"bytes"
"errors"
"fmt"
"net/http"
"strings"
"sync"
)
const (
	// DefaultDynamicTableSize is the default size of the dynamic table (4KB)
	DefaultDynamicTableSize = 4096
	// indexedHeaderField is the wire prefix for a fully indexed header field
	indexedHeaderField = 0x80 // 10xxxxxx
	// literalHeaderFieldWithIndexing is the wire prefix for a literal header
	// field with incremental indexing
	literalHeaderFieldWithIndexing = 0x40 // 01xxxxxx
)
// Encoder compresses HTTP headers using HPACK
// Each connection MUST have its own encoder instance to avoid state corruption
type Encoder struct {
	mu           sync.Mutex    // serializes Encode, SetMaxTableSize and Reset
	dynamicTable *DynamicTable // per-connection table; must mirror the peer decoder's table
	staticTable  *StaticTable  // process-wide immutable table (RFC 7541 Appendix A)
	maxTableSize uint32        // size used when Reset rebuilds the dynamic table
}
// NewEncoder creates an HPACK encoder whose dynamic table is capped at
// maxTableSize bytes (DefaultDynamicTableSize when zero). Calls into the
// encoder are serialized by an internal mutex, but compression state is
// per-connection and must not be shared between connections.
func NewEncoder(maxTableSize uint32) *Encoder {
	size := maxTableSize
	if size == 0 {
		size = DefaultDynamicTableSize
	}
	return &Encoder{
		dynamicTable: NewDynamicTable(size),
		staticTable:  GetStaticTable(),
		maxTableSize: size,
	}
}
// Encode serializes headers into an HPACK header block.
// A nil header map is rejected; an empty map yields an empty block.
func (e *Encoder) Encode(headers http.Header) ([]byte, error) {
	e.mu.Lock()
	defer e.mu.Unlock()
	if headers == nil {
		return nil, errors.New("headers cannot be nil")
	}
	var out bytes.Buffer
	for name, values := range headers {
		for _, value := range values {
			if err := e.encodeHeaderField(&out, name, value); err != nil {
				return nil, fmt.Errorf("encode header %s: %w", name, err)
			}
		}
	}
	return out.Bytes(), nil
}
// encodeHeaderField encodes one name/value pair, preferring (in order) an
// exact static match, a static name reference, an exact dynamic match, a
// dynamic name reference, and finally a fully literal representation.
//
// Header names are lowercased first: HTTP/2 requires lowercase field names
// (RFC 7540 Section 8.1.2).
func (e *Encoder) encodeHeaderField(buf *bytes.Buffer, name, value string) error {
	nameLower := strings.ToLower(name)
	// Exact static match: pure index, nothing enters the dynamic table.
	if index, found := e.staticTable.FindExact(nameLower, value); found {
		return e.writeIndexedHeader(buf, index+1)
	}
	// Static name + literal value, with incremental indexing (01xxxxxx).
	if index, found := e.staticTable.FindName(nameLower); found {
		if err := e.writeLiteralWithIndexing(buf, index+1, value, true); err != nil {
			return err
		}
		// BUG FIX: any 01-prefixed emission obliges BOTH endpoints to add
		// the pair to their dynamic tables (RFC 7541 Section 6.2.1). The
		// decoder side already does (decodeLiteralWithIndexing); failing to
		// mirror it here desynchronized every subsequent dynamic index.
		e.dynamicTable.Add(nameLower, value)
		return nil
	}
	staticSize := uint32(e.staticTable.Size())
	// Exact dynamic match; dynamic indices start after the static table.
	if index, found := e.dynamicTable.FindExact(nameLower, value); found {
		return e.writeIndexedHeader(buf, staticSize+index+1)
	}
	// Dynamic name + literal value, again with incremental indexing. The
	// index is emitted against the pre-insert table, matching decode order.
	if index, found := e.dynamicTable.FindName(nameLower); found {
		if err := e.writeLiteralWithIndexing(buf, staticSize+index+1, value, true); err != nil {
			return err
		}
		e.dynamicTable.Add(nameLower, value)
		return nil
	}
	// Unknown name: literal name + literal value with incremental indexing.
	// The bare flag byte encodes name-index 0 ("name follows as literal").
	buf.WriteByte(literalHeaderFieldWithIndexing)
	if err := e.writeString(buf, nameLower, false); err != nil {
		return err
	}
	if err := e.writeString(buf, value, false); err != nil {
		return err
	}
	e.dynamicTable.Add(nameLower, value)
	return nil
}
// writeIndexedHeader emits a fully indexed field (10xxxxxx, RFC 7541
// Section 6.1) for the given 1-based table index.
func (e *Encoder) writeIndexedHeader(buf *bytes.Buffer, index uint32) error {
	return e.writeInteger(buf, index, 7, indexedHeaderField)
}
// writeLiteralWithIndexing writes a literal header with incremental indexing (01xxxxxx)
// nameIndex is a 1-based table index used when hasIndex is true.
//
// NOTE(review): when hasIndex is false this writes only the flag byte and
// the value — no name string — which would not decode as a valid field
// (the flag byte alone means "name index 0, literal name follows"). All
// callers in this file pass hasIndex=true; confirm the false branch is
// dead, or complete it, before relying on it.
func (e *Encoder) writeLiteralWithIndexing(buf *bytes.Buffer, nameIndex uint32, value string, hasIndex bool) error {
	if hasIndex {
		// Write name as index
		if err := e.writeInteger(buf, nameIndex, 6, literalHeaderFieldWithIndexing); err != nil {
			return err
		}
	} else {
		// Write literal flag
		buf.WriteByte(literalHeaderFieldWithIndexing)
	}
	// Write value as literal string
	return e.writeString(buf, value, false)
}
// writeInteger emits an HPACK integer (RFC 7541 Section 5.1): the value in
// prefixBits bits when it fits, otherwise a saturated prefix followed by
// 7-bit continuation groups, least significant first.
func (e *Encoder) writeInteger(buf *bytes.Buffer, value uint32, prefixBits int, prefix byte) error {
	if prefixBits < 1 || prefixBits > 8 {
		return fmt.Errorf("invalid prefix bits: %d", prefixBits)
	}
	limit := uint32(1)<<prefixBits - 1
	if value < limit {
		buf.WriteByte(prefix | byte(value))
		return nil
	}
	buf.WriteByte(prefix | byte(limit))
	rest := value - limit
	for rest >= 0x80 {
		buf.WriteByte(byte(rest&0x7f) | 0x80)
		rest >>= 7
	}
	buf.WriteByte(byte(rest))
	return nil
}
// writeString emits an HPACK string literal: a 7-bit length prefix with the
// Huffman bit clear, then the raw bytes. Huffman coding is unsupported.
func (e *Encoder) writeString(buf *bytes.Buffer, str string, huffmanEncode bool) error {
	if huffmanEncode {
		// TODO: Implement Huffman encoding if needed
		return errors.New("huffman encoding not implemented")
	}
	// Length with H=0 (no Huffman), then the bytes themselves.
	if err := e.writeInteger(buf, uint32(len(str)), 7, 0x00); err != nil {
		return err
	}
	buf.WriteString(str)
	return nil
}
// SetMaxTableSize resizes the dynamic table, evicting entries if needed.
func (e *Encoder) SetMaxTableSize(size uint32) {
	e.mu.Lock()
	e.maxTableSize = size
	e.dynamicTable.SetMaxSize(size)
	e.mu.Unlock()
}

// Reset discards all dynamic table state, restoring an empty table of the
// currently configured maximum size.
func (e *Encoder) Reset() {
	e.mu.Lock()
	e.dynamicTable = NewDynamicTable(e.maxTableSize)
	e.mu.Unlock()
}

View File

@@ -1,150 +0,0 @@
package hpack
import (
"fmt"
"sync"
)
// StaticTable implements the HPACK static table (RFC 7541 Appendix A).
// It is built once, never mutated afterwards, and shared process-wide.
type StaticTable struct {
	entries []HeaderField
	nameMap map[string][]uint32 // header name -> every index carrying it
}

var (
	staticTableInstance *StaticTable
	staticTableOnce     sync.Once
)

// GetStaticTable returns the lazily-initialized singleton static table.
func GetStaticTable() *StaticTable {
	staticTableOnce.Do(func() {
		staticTableInstance = newStaticTable()
	})
	return staticTableInstance
}

// newStaticTable builds the table contents and its name lookup index.
func newStaticTable() *StaticTable {
	// RFC 7541 Appendix A - Static Table Definition
	// We include the most common headers for HTTP
	entries := []HeaderField{
		{Name: ":authority", Value: ""},
		{Name: ":method", Value: "GET"},
		{Name: ":method", Value: "POST"},
		{Name: ":path", Value: "/"},
		{Name: ":path", Value: "/index.html"},
		{Name: ":scheme", Value: "http"},
		{Name: ":scheme", Value: "https"},
		{Name: ":status", Value: "200"},
		{Name: ":status", Value: "204"},
		{Name: ":status", Value: "206"},
		{Name: ":status", Value: "304"},
		{Name: ":status", Value: "400"},
		{Name: ":status", Value: "404"},
		{Name: ":status", Value: "500"},
		{Name: "accept-charset", Value: ""},
		{Name: "accept-encoding", Value: "gzip, deflate"},
		{Name: "accept-language", Value: ""},
		{Name: "accept-ranges", Value: ""},
		{Name: "accept", Value: ""},
		{Name: "access-control-allow-origin", Value: ""},
		{Name: "age", Value: ""},
		{Name: "allow", Value: ""},
		{Name: "authorization", Value: ""},
		{Name: "cache-control", Value: ""},
		{Name: "content-disposition", Value: ""},
		{Name: "content-encoding", Value: ""},
		{Name: "content-language", Value: ""},
		{Name: "content-length", Value: ""},
		{Name: "content-location", Value: ""},
		{Name: "content-range", Value: ""},
		{Name: "content-type", Value: ""},
		{Name: "cookie", Value: ""},
		{Name: "date", Value: ""},
		{Name: "etag", Value: ""},
		{Name: "expect", Value: ""},
		{Name: "expires", Value: ""},
		{Name: "from", Value: ""},
		{Name: "host", Value: ""},
		{Name: "if-match", Value: ""},
		{Name: "if-modified-since", Value: ""},
		{Name: "if-none-match", Value: ""},
		{Name: "if-range", Value: ""},
		{Name: "if-unmodified-since", Value: ""},
		{Name: "last-modified", Value: ""},
		{Name: "link", Value: ""},
		{Name: "location", Value: ""},
		{Name: "max-forwards", Value: ""},
		{Name: "proxy-authenticate", Value: ""},
		{Name: "proxy-authorization", Value: ""},
		{Name: "range", Value: ""},
		{Name: "referer", Value: ""},
		{Name: "refresh", Value: ""},
		{Name: "retry-after", Value: ""},
		{Name: "server", Value: ""},
		{Name: "set-cookie", Value: ""},
		{Name: "strict-transport-security", Value: ""},
		{Name: "transfer-encoding", Value: ""},
		{Name: "user-agent", Value: ""},
		{Name: "vary", Value: ""},
		{Name: "via", Value: ""},
		{Name: "www-authenticate", Value: ""},
	}
	// Index every entry by name so lookups avoid a linear scan.
	index := make(map[string][]uint32, len(entries))
	for i, entry := range entries {
		index[entry.Name] = append(index[entry.Name], uint32(i))
	}
	return &StaticTable{entries: entries, nameMap: index}
}

// Get retrieves the header field stored at the given 0-based index.
func (st *StaticTable) Get(index uint32) (string, string, error) {
	if index >= uint32(len(st.entries)) {
		return "", "", fmt.Errorf("index %d out of range (static table size: %d)", index, len(st.entries))
	}
	entry := st.entries[index]
	return entry.Name, entry.Value, nil
}

// FindExact searches for an entry matching both name and value, returning
// its 0-based index and whether it was found.
func (st *StaticTable) FindExact(name, value string) (uint32, bool) {
	for _, index := range st.nameMap[name] {
		if st.entries[index].Value == value {
			return index, true
		}
	}
	return 0, false
}

// FindName returns the first 0-based index whose entry carries the given
// name, and whether any such entry exists.
func (st *StaticTable) FindName(name string) (uint32, bool) {
	indices := st.nameMap[name]
	if len(indices) == 0 {
		return 0, false
	}
	return indices[0], true
}

// Size returns the number of entries in the static table.
func (st *StaticTable) Size() int {
	return len(st.entries)
}

View File

@@ -9,6 +9,10 @@ const (
// DefaultWSPort is the default WebSocket port
DefaultWSPort = 8080
// YamuxAcceptBacklog controls how many incoming streams can be queued
// before yamux starts blocking stream opens under load.
YamuxAcceptBacklog = 4096
// HeartbeatInterval is how often clients send heartbeat messages
HeartbeatInterval = 2 * time.Second

View File

@@ -0,0 +1,71 @@
package httputil
import (
"fmt"
"io"
"net/http"
"strings"
)
// CopyHeaders copies all headers from src to dst.
func CopyHeaders(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// CleanHopByHopHeaders removes hop-by-hop headers that should not be forwarded.
func CleanHopByHopHeaders(headers http.Header) {
if headers == nil {
return
}
if connectionHeaders := headers.Get("Connection"); connectionHeaders != "" {
for _, token := range strings.Split(connectionHeaders, ",") {
if t := strings.TrimSpace(token); t != "" {
headers.Del(http.CanonicalHeaderKey(t))
}
}
}
for _, key := range []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
"Proxy-Connection",
} {
headers.Del(key)
}
}
// WriteProxyError writes an HTTP error response to the writer.
func WriteProxyError(w io.Writer, code int, msg string) {
body := msg
resp := &http.Response{
StatusCode: code,
Status: fmt.Sprintf("%d %s", code, http.StatusText(code)),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
ContentLength: int64(len(body)),
Close: true,
}
resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
_ = resp.Write(w)
_ = resp.Body.Close()
}
// IsWebSocketUpgrade checks if the request is a WebSocket upgrade request.
func IsWebSocketUpgrade(req *http.Request) bool {
return strings.EqualFold(req.Header.Get("Upgrade"), "websocket") &&
strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
}

View File

@@ -0,0 +1,35 @@
package netutil
import "net"
// CountingConn wraps a net.Conn to count bytes read/written.
type CountingConn struct {
net.Conn
OnRead func(int64)
OnWrite func(int64)
}
// NewCountingConn creates a new CountingConn.
func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
return &CountingConn{
Conn: conn,
OnRead: onRead,
OnWrite: onWrite,
}
}
func (c *CountingConn) Read(p []byte) (int, error) {
n, err := c.Conn.Read(p)
if n > 0 && c.OnRead != nil {
c.OnRead(int64(n))
}
return n, err
}
func (c *CountingConn) Write(p []byte) (int, error) {
n, err := c.Conn.Write(p)
if n > 0 && c.OnWrite != nil {
c.OnWrite(int64(n))
}
return n, err
}

View File

@@ -0,0 +1,164 @@
package netutil
import (
"context"
"io"
"sync"
"time"
"drip/internal/shared/pool"
)
// tcpWaitTimeout bounds how long a half-closed connection may linger while
// the opposite direction drains (see pipeBuffer).
const tcpWaitTimeout = 10 * time.Second

// closeReader is satisfied by connections (e.g. *net.TCPConn) that can
// half-close their read side.
type closeReader interface {
	CloseRead() error
}

// closeWriter is satisfied by connections that can half-close their write side.
type closeWriter interface {
	CloseWrite() error
}

// readDeadliner is satisfied by connections that support read deadlines.
type readDeadliner interface {
	SetReadDeadline(t time.Time) error
}
// Pipe copies bytes bidirectionally between a and b (gost-like), applying
// TCP half-close where the endpoints support it.
func Pipe(ctx context.Context, a, b io.ReadWriteCloser) error {
	return PipeWithCallbacks(ctx, a, b, nil, nil)
}

// PipeWithCallbacks is Pipe with optional per-direction byte counters:
// onAToB observes bytes copied a -> b, onBToA observes b -> a.
func PipeWithCallbacks(ctx context.Context, a, b io.ReadWriteCloser, onAToB func(n int64), onBToA func(n int64)) error {
	return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, onAToB, onBToA)
}

// PipeWithBufferSize is Pipe with an explicit copy-buffer size.
func PipeWithBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int) error {
	return PipeWithCallbacksAndBufferSize(ctx, a, b, bufSize, nil, nil)
}
// PipeWithCallbacksAndBufferSize is PipeWithCallbacks with a custom buffer size.
//
// It runs one copier goroutine per direction and returns once both finish.
// Whichever copier stops first (EOF, error, or ctx cancellation) closes
// BOTH endpoints via closeAll, which unblocks the peer copier's pending
// Read/Write. At most one error is reported: errCh has capacity 2 so
// neither sender blocks, but the final select drains a single entry.
func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int, onAToB func(n int64), onBToA func(n int64)) error {
	// Clamp the buffer size into (0, pool.SizeLarge].
	if bufSize <= 0 {
		bufSize = pool.SizeMedium
	}
	if bufSize > pool.SizeLarge {
		bufSize = pool.SizeLarge
	}
	var wg sync.WaitGroup
	wg.Add(2)
	stopCh := make(chan struct{})
	var closeOnce sync.Once
	// closeAll tears down both endpoints exactly once and signals every
	// helper goroutine through stopCh.
	closeAll := func() {
		closeOnce.Do(func() {
			close(stopCh)
			_ = a.Close()
			_ = b.Close()
		})
	}
	errCh := make(chan error, 2)
	if ctx != nil {
		// Watcher: cancellation tears the pipe down; stopCh stops the
		// watcher itself so it cannot leak after a normal shutdown.
		go func() {
			select {
			case <-ctx.Done():
				closeAll()
			case <-stopCh:
			}
		}()
	}
	go func() {
		defer wg.Done()
		// a -> b direction; onAToB observes bytes written to b.
		err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
		if err != nil {
			errCh <- err
		}
		closeAll()
	}()
	go func() {
		defer wg.Done()
		// b -> a direction; onBToA observes bytes written to a.
		err := pipeBuffer(a, b, bufSize, onBToA, stopCh)
		if err != nil {
			errCh <- err
		}
		closeAll()
	}()
	wg.Wait()
	// NOTE(review): a context cancellation surfaces here as io.EOF (see
	// copyBuffer's stopCh path) rather than ctx.Err() — confirm callers do
	// not need to distinguish cancellation from a normal peer close.
	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}
// pipeBuffer copies src -> dst through a pooled buffer, then propagates
// end-of-stream: it half-closes the read side of src and the write side of
// dst where supported (TCP half-close), falling back to a full Close on dst
// otherwise, and finally arms a short read deadline on dst so a peer that
// never finishes cannot hold the connection open indefinitely.
func pipeBuffer(dst io.ReadWriteCloser, src io.ReadWriteCloser, bufSize int, onCopied func(n int64), stopCh <-chan struct{}) error {
	bufPtr := pool.GetBuffer(bufSize)
	defer pool.PutBuffer(bufPtr)
	buf := (*bufPtr)[:bufSize]
	_, err := copyBuffer(dst, src, buf, onCopied, stopCh)
	// We will read no more from src: release its read side if possible.
	if cr, ok := src.(closeReader); ok {
		_ = cr.CloseRead()
	}
	if cw, ok := dst.(closeWriter); ok {
		// Half-close the write side; if that fails, hard-close instead.
		if e := cw.CloseWrite(); e != nil {
			_ = dst.Close()
		}
		// Bound how long the reverse direction may keep dst readable.
		if rd, ok := dst.(readDeadliner); ok {
			_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
		}
	} else {
		// No half-close support: close dst outright.
		_ = dst.Close()
		// NOTE(review): setting a read deadline after Close looks redundant
		// (net.Conn implementations generally reject calls on a closed
		// connection) — confirm whether this branch actually needs it.
		if rd, ok := dst.(readDeadliner); ok {
			_ = rd.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
		}
	}
	return err
}
func copyBuffer(dst io.Writer, src io.Reader, buf []byte, onCopied func(n int64), stopCh <-chan struct{}) (written int64, err error) {
for {
select {
case <-stopCh:
return written, io.EOF
default:
}
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
if onCopied != nil {
onCopied(int64(nw))
}
}
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if er != nil {
if er == io.EOF {
return written, nil
}
return written, er
}
}
}

View File

@@ -1,73 +0,0 @@
package pool
import (
"sync"
)
// AdaptiveBufferPool hands out reusable byte buffers in two size classes,
// eliminating per-transfer allocations on hot copy paths.
type AdaptiveBufferPool struct {
	largePool  *sync.Pool // 1MB buffers for streaming
	mediumPool *sync.Pool // 32KB buffers for temporary reads
}

const (
	// LargeBufferSize is 1MB for streaming threshold
	LargeBufferSize = 1 * 1024 * 1024
	// MediumBufferSize is 32KB for temporary reads
	MediumBufferSize = 32 * 1024
)

// NewAdaptiveBufferPool creates a pool pair whose buffers are allocated
// lazily on first Get.
func NewAdaptiveBufferPool() *AdaptiveBufferPool {
	newSized := func(size int) *sync.Pool {
		return &sync.Pool{
			New: func() interface{} {
				buf := make([]byte, size)
				return &buf
			},
		}
	}
	return &AdaptiveBufferPool{
		largePool:  newSized(LargeBufferSize),
		mediumPool: newSized(MediumBufferSize),
	}
}

// GetLarge returns a 1MB buffer; hand it back with PutLarge when done.
func (p *AdaptiveBufferPool) GetLarge() *[]byte {
	return p.largePool.Get().(*[]byte)
}

// PutLarge recycles a large buffer for reuse; nil is ignored.
func (p *AdaptiveBufferPool) PutLarge(buf *[]byte) {
	if buf == nil {
		return
	}
	// Restore full length so the next user sees the whole capacity.
	*buf = (*buf)[:cap(*buf)]
	p.largePool.Put(buf)
}

// GetMedium returns a 32KB buffer; hand it back with PutMedium when done.
func (p *AdaptiveBufferPool) GetMedium() *[]byte {
	return p.mediumPool.Get().(*[]byte)
}

// PutMedium recycles a medium buffer for reuse; nil is ignored.
func (p *AdaptiveBufferPool) PutMedium(buf *[]byte) {
	if buf == nil {
		return
	}
	*buf = (*buf)[:cap(*buf)]
	p.mediumPool.Put(buf)
}

View File

@@ -1,86 +0,0 @@
package pool
import (
"net/http"
"sync"
)
// HeaderPool manages a pool of http.Header objects for reuse.
type HeaderPool struct {
pool sync.Pool
}
// NewHeaderPool creates a new header pool
func NewHeaderPool() *HeaderPool {
return &HeaderPool{
pool: sync.Pool{
New: func() interface{} {
return make(http.Header, 12)
},
},
}
}
// Get retrieves a header from the pool.
func (p *HeaderPool) Get() http.Header {
h := p.pool.Get().(http.Header)
for k := range h {
delete(h, k)
}
return h
}
// Put returns a header to the pool.
func (p *HeaderPool) Put(h http.Header) {
if h == nil {
return
}
p.pool.Put(h)
}
// Clone creates a copy of src into dst, reusing dst's underlying storage
// This is more efficient than creating a new header from scratch
func (p *HeaderPool) Clone(dst, src http.Header) {
// Clear dst first
for k := range dst {
delete(dst, k)
}
// Copy all headers from src to dst
for k, vv := range src {
// Allocate new slice with exact capacity to avoid over-allocation
dst[k] = make([]string, len(vv))
copy(dst[k], vv)
}
}
// CloneWithExtra clones src into dst and adds/overwrites extra headers
// This is optimized for the common pattern of cloning + adding Host header
func (p *HeaderPool) CloneWithExtra(dst, src http.Header, extraKey, extraValue string) {
// Clear dst first
for k := range dst {
delete(dst, k)
}
// Copy all headers from src to dst
for k, vv := range src {
dst[k] = make([]string, len(vv))
copy(dst[k], vv)
}
// Set extra header (overwrite if exists)
dst.Set(extraKey, extraValue)
}
// globalHeaderPool is a package-level pool for convenience
var globalHeaderPool = NewHeaderPool()
// GetHeader retrieves a header from the global pool
func GetHeader() http.Header {
return globalHeaderPool.Get()
}
// PutHeader returns a header to the global pool
func PutHeader(h http.Header) {
globalHeaderPool.Put(h)
}

View File

@@ -2,81 +2,23 @@ package protocol
import (
"sync/atomic"
"time"
"drip/internal/shared/pool"
)
// NOTE(review): this span appears to be a rendered commit diff with the OLD
// and NEW versions of the file interleaved and the +/- markers lost: the
// struct carries a duplicate activeConnections field, globalAdaptiveManager
// is declared twice, and several function bodies contain both the old and
// new statements. It cannot compile as-is; the suspect lines are flagged
// below. Recover the intended version from the original commit before use.
// AdaptivePoolManager dynamically adjusts buffer pool usage based on load
// AdaptivePoolManager tracks active connections for load monitoring
type AdaptivePoolManager struct {
	activeConnections atomic.Int64
	currentThreshold atomic.Int64
	highLoadConnectionThreshold int64
	midLoadConnectionThreshold int64
	midLoadThreshold int64
	highLoadThreshold int64
	activeConnections atomic.Int64 // NOTE(review): duplicate field — old/new merged
}
var globalAdaptiveManager = NewAdaptivePoolManager() // NOTE(review): first of two declarations
func NewAdaptivePoolManager() *AdaptivePoolManager {
	m := &AdaptivePoolManager{
	highLoadConnectionThreshold: 300,
	midLoadConnectionThreshold: 150,
	midLoadThreshold: int64(pool.SizeLarge),
	highLoadThreshold: int64(pool.SizeMedium),
	}
	m.currentThreshold.Store(m.midLoadThreshold)
	go m.monitor() // background sampler; runs for process lifetime
	return m
}
// monitor re-evaluates the buffer-size threshold once per second based on
// the active connection count, with a hysteresis band between the two
// connection thresholds to avoid flapping.
func (m *AdaptivePoolManager) monitor() {
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
	connections := m.activeConnections.Load()
	if connections >= m.highLoadConnectionThreshold {
	m.currentThreshold.Store(m.highLoadThreshold)
	} else if connections < m.midLoadConnectionThreshold {
	m.currentThreshold.Store(m.midLoadThreshold)
	}
	// Hysteresis zone (150-300): maintain current threshold
	}
}
// GetThreshold returns the currently selected buffer-size threshold.
func (m *AdaptivePoolManager) GetThreshold() int {
	return int(m.currentThreshold.Load())
}
// RegisterConnection records one new active connection.
func (m *AdaptivePoolManager) RegisterConnection() {
	m.activeConnections.Add(1)
}
// UnregisterConnection records one connection going away.
func (m *AdaptivePoolManager) UnregisterConnection() {
	m.activeConnections.Add(-1)
}
// GetActiveConnections reports the current active connection count.
func (m *AdaptivePoolManager) GetActiveConnections() int64 {
	return m.activeConnections.Load()
}
// GetAdaptiveThreshold exposes the global manager's threshold.
func GetAdaptiveThreshold() int {
	return globalAdaptiveManager.GetThreshold()
}
var globalAdaptiveManager = &AdaptivePoolManager{} // NOTE(review): second declaration — old/new merged
func RegisterConnection() {
	globalAdaptiveManager.RegisterConnection()
	globalAdaptiveManager.activeConnections.Add(1) // NOTE(review): duplicated effect — old/new bodies merged
}
func UnregisterConnection() {
	globalAdaptiveManager.UnregisterConnection()
	globalAdaptiveManager.activeConnections.Add(-1) // NOTE(review): duplicated effect — old/new bodies merged
}
func GetActiveConnections() int64 {
	return globalAdaptiveManager.GetActiveConnections()
	return globalAdaptiveManager.activeConnections.Load() // NOTE(review): unreachable — old/new bodies merged
}

View File

@@ -1,162 +0,0 @@
package protocol
import (
"encoding/binary"
"errors"
)
// DataHeader is the binary-encoded frame header for the data plane. All
// data transmission uses pure binary encoding for performance.
type DataHeader struct {
	Type      DataType
	IsLast    bool
	StreamID  string
	RequestID string
}

// DataType identifies the kind of data frame; only 3 bits are available on
// the wire, so request-side streaming reuses the shared streaming codes.
type DataType uint8

const (
	DataTypeData         DataType = 0x00 // 000
	DataTypeResponse     DataType = 0x01 // 001
	DataTypeClose        DataType = 0x02 // 010
	DataTypeHTTPRequest  DataType = 0x03 // 011
	DataTypeHTTPResponse DataType = 0x04 // 100
	DataTypeHTTPHead     DataType = 0x05 // 101 - streaming headers (shared)
	DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
	// Reuse the same type codes for request streaming to stay within 3 bits.
	DataTypeHTTPRequestHead      DataType = DataTypeHTTPHead
	DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
)

// dataTypeNames maps each defined DataType code to its wire-protocol name;
// the slice index is the type code.
var dataTypeNames = [...]string{
	"data",
	"response",
	"close",
	"http_request",
	"http_response",
	"http_head",
	"http_body_chunk",
}

// String returns the string representation of DataType, or "unknown" for
// codes outside the defined range.
func (t DataType) String() string {
	if int(t) < len(dataTypeNames) {
		return dataTypeNames[t]
	}
	return "unknown"
}

// DataTypeFromString converts a string to a DataType, defaulting to
// DataTypeData for unrecognized input.
func DataTypeFromString(s string) DataType {
	for code, name := range dataTypeNames {
		if name == s {
			return DataType(code)
		}
	}
	return DataTypeData
}

// Binary format:
// +--------+--------+--------+--------+--------+
// | Flags  | StreamID Length | RequestID Len   |
// | 1 byte | 2 bytes         | 2 bytes         |
// +--------+--------+--------+--------+--------+
// | StreamID (variable)                        |
// +--------+--------+--------+--------+--------+
// | RequestID (variable)                       |
// +--------+--------+--------+--------+--------+
//
// Flags (8 bits):
//   - Bit 0-2: Type (3 bits)
//   - Bit 3:   IsLast (1 bit)
//   - Bit 4-7: Reserved (4 bits)
const (
	binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
)

// MarshalBinary encodes the header into the binary wire format above.
func (h *DataHeader) MarshalBinary() []byte {
	sidLen := len(h.StreamID)
	ridLen := len(h.RequestID)
	out := make([]byte, binaryHeaderMinSize+sidLen+ridLen)
	// Flags: type in bits 0-2, IsLast in bit 3.
	flags := byte(h.Type) & 0x07
	if h.IsLast {
		flags |= 0x08
	}
	out[0] = flags
	// Big-endian lengths, then the two ID strings back to back.
	binary.BigEndian.PutUint16(out[1:3], uint16(sidLen))
	binary.BigEndian.PutUint16(out[3:5], uint16(ridLen))
	n := copy(out[binaryHeaderMinSize:], h.StreamID)
	copy(out[binaryHeaderMinSize+n:], h.RequestID)
	return out
}
// UnmarshalBinary decodes the header from binary format
func (h *DataHeader) UnmarshalBinary(data []byte) error {
if len(data) < binaryHeaderMinSize {
return errors.New("invalid binary header: too short")
}
// Decode flags
flags := data[0]
h.Type = DataType(flags & 0x07) // Bits 0-2
h.IsLast = (flags & 0x08) != 0 // Bit 3
// Decode lengths
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
requestIDLen := int(binary.BigEndian.Uint16(data[3:5]))
// Validate total length
expectedLen := binaryHeaderMinSize + streamIDLen + requestIDLen
if len(data) < expectedLen {
return errors.New("invalid binary header: length mismatch")
}
// Decode StreamID
offset := binaryHeaderMinSize
h.StreamID = string(data[offset : offset+streamIDLen])
offset += streamIDLen
// Decode RequestID
h.RequestID = string(data[offset : offset+requestIDLen])
return nil
}
// Size returns the size of the binary-encoded header
func (h *DataHeader) Size() int {
return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
}

View File

@@ -1,34 +0,0 @@
package protocol
import (
json "github.com/goccy/go-json"
)
// FlowControlAction identifies a stream-level flow-control command.
type FlowControlAction string

const (
	// FlowControlPause signals that data delivery for the stream should pause.
	FlowControlPause FlowControlAction = "pause"
	// FlowControlResume signals that data delivery for the stream may resume.
	FlowControlResume FlowControlAction = "resume"
)

// FlowControlMessage is the JSON-encoded payload of a flow-control frame.
type FlowControlMessage struct {
	StreamID string            `json:"stream_id"`
	Action   FlowControlAction `json:"action"`
}
// NewFlowControlFrame builds a FrameTypeFlowControl frame whose payload is a
// JSON-encoded FlowControlMessage for the given stream and action.
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
	// Marshaling this fixed two-field struct cannot fail, so the error is
	// deliberately discarded.
	payload, _ := json.Marshal(&FlowControlMessage{
		StreamID: streamID,
		Action:   action,
	})
	return NewFrame(FrameTypeFlowControl, payload)
}
// DecodeFlowControlMessage parses a JSON flow-control payload.
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
	msg := new(FlowControlMessage)
	if err := json.Unmarshal(payload, msg); err != nil {
		return nil, err
	}
	return msg, nil
}

View File

@@ -18,14 +18,14 @@ const (
type FrameType byte
const (
FrameTypeRegister FrameType = 0x01
FrameTypeRegisterAck FrameType = 0x02
FrameTypeHeartbeat FrameType = 0x03
FrameTypeHeartbeatAck FrameType = 0x04
FrameTypeData FrameType = 0x05
FrameTypeClose FrameType = 0x06
FrameTypeError FrameType = 0x07
FrameTypeFlowControl FrameType = 0x08
FrameTypeRegister FrameType = 0x01
FrameTypeRegisterAck FrameType = 0x02
FrameTypeHeartbeat FrameType = 0x03
FrameTypeHeartbeatAck FrameType = 0x04
FrameTypeClose FrameType = 0x05
FrameTypeError FrameType = 0x06
FrameTypeDataConnect FrameType = 0x07
FrameTypeDataConnectAck FrameType = 0x08
)
// String returns the string representation of frame type
@@ -39,14 +39,14 @@ func (t FrameType) String() string {
return "Heartbeat"
case FrameTypeHeartbeatAck:
return "HeartbeatAck"
case FrameTypeData:
return "Data"
case FrameTypeClose:
return "Close"
case FrameTypeError:
return "Error"
case FrameTypeFlowControl:
return "FlowControl"
case FrameTypeDataConnect:
return "DataConnect"
case FrameTypeDataConnectAck:
return "DataConnectAck"
default:
return fmt.Sprintf("Unknown(%d)", t)
}
@@ -56,6 +56,9 @@ type Frame struct {
Type FrameType
Payload []byte
poolBuffer *[]byte
// queuedBytes is set by FrameWriter when the frame is enqueued.
// It allows the writer to decrement backlog counters exactly once.
queuedBytes int64
}
func WriteFrame(w io.Writer, frame *Frame) error {
@@ -130,6 +133,8 @@ func (f *Frame) Release() {
f.poolBuffer = nil
f.Payload = nil
}
// Reset queued marker to avoid carrying over stale state if the frame is reused.
f.queuedBytes = 0
}
// NewFrame creates a new frame

View File

@@ -1,119 +0,0 @@
package protocol
import (
"errors"
json "github.com/goccy/go-json"
"github.com/vmihailenco/msgpack/v5"
)
// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (the
// optimized v2 wire format; see DecodeHTTPRequest for v1/v2 auto-detection).
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
	return msgpack.Marshal(req)
}
// DecodeHTTPRequest decodes an HTTPRequest payload with automatic version
// detection: a leading '{' marks legacy v1 JSON, anything else is treated as
// v2 msgpack (fixmap encodings start at 0x80).
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	unmarshal := msgpack.Unmarshal
	if data[0] == '{' {
		unmarshal = json.Unmarshal
	}
	req := new(HTTPRequest)
	if err := unmarshal(data, req); err != nil {
		return nil, err
	}
	return req, nil
}
// EncodeHTTPRequestHead encodes HTTP request headers for streaming
// (msgpack wire format, no body).
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
	return msgpack.Marshal(head)
}
// DecodeHTTPRequestHead decodes streaming request headers, auto-detecting
// JSON (leading '{') versus msgpack payloads.
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	unmarshal := msgpack.Unmarshal
	if data[0] == '{' {
		unmarshal = json.Unmarshal
	}
	head := new(HTTPRequestHead)
	if err := unmarshal(data, head); err != nil {
		return nil, err
	}
	return head, nil
}
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (the
// optimized v2 wire format; see DecodeHTTPResponse for v1/v2 auto-detection).
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
	return msgpack.Marshal(resp)
}
// DecodeHTTPResponse decodes an HTTPResponse payload with automatic version
// detection: a leading '{' marks legacy v1 JSON, anything else is treated as
// v2 msgpack (fixmap encodings start at 0x80).
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	unmarshal := msgpack.Unmarshal
	if data[0] == '{' {
		unmarshal = json.Unmarshal
	}
	resp := new(HTTPResponse)
	if err := unmarshal(data, resp); err != nil {
		return nil, err
	}
	return resp, nil
}
// EncodeHTTPResponseHead encodes HTTP response headers for streaming
// (msgpack wire format, no body).
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
	return msgpack.Marshal(head)
}
// DecodeHTTPResponseHead decodes streaming response headers, auto-detecting
// JSON (leading '{') versus msgpack payloads.
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	unmarshal := msgpack.Unmarshal
	if data[0] == '{' {
		unmarshal = json.Unmarshal
	}
	head := new(HTTPResponseHead)
	if err := unmarshal(data, head); err != nil {
		return nil, err
	}
	return head, nil
}

View File

@@ -1,71 +0,0 @@
package protocol
// MessageType defines the type of tunnel message.
type MessageType string

const (
	// TypeRegister is sent when a client connects and gets a subdomain assigned.
	TypeRegister MessageType = "register"
	// TypeRequest is sent from server to client when an HTTP request arrives.
	TypeRequest MessageType = "request"
	// TypeResponse is sent from client to server with the HTTP response.
	TypeResponse MessageType = "response"
	// TypeHeartbeat is sent periodically to keep the connection alive.
	TypeHeartbeat MessageType = "heartbeat"
	// TypeError is sent when an error occurs.
	TypeError MessageType = "error"
)

// Message represents a tunnel protocol message. Which optional fields are
// populated depends on Type.
type Message struct {
	Type      MessageType            `json:"type"`
	ID        string                 `json:"id,omitempty"`
	Subdomain string                 `json:"subdomain,omitempty"`
	Data      map[string]interface{} `json:"data,omitempty"`
	Error     string                 `json:"error,omitempty"`
}

// HTTPRequest represents an HTTP request to be forwarded.
type HTTPRequest struct {
	Method  string              `json:"method"`
	URL     string              `json:"url"`
	Headers map[string][]string `json:"headers"`
	Body    []byte              `json:"body,omitempty"`
}

// HTTPRequestHead represents HTTP request headers for streaming (no body).
type HTTPRequestHead struct {
	Method        string              `json:"method"`
	URL           string              `json:"url"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}

// HTTPResponse represents an HTTP response from the local service.
type HTTPResponse struct {
	StatusCode int                 `json:"status_code"`
	Status     string              `json:"status"`
	Headers    map[string][]string `json:"headers"`
	Body       []byte              `json:"body,omitempty"`
}

// HTTPResponseHead represents HTTP response headers for streaming (no body).
type HTTPResponseHead struct {
	StatusCode    int                 `json:"status_code"`
	Status        string              `json:"status"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}

// RegisterData contains information sent when a tunnel is registered.
type RegisterData struct {
	Subdomain string `json:"subdomain"`
	URL       string `json:"url"`
	Message   string `json:"message"`
}

// ErrorData contains error information.
type ErrorData struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

View File

@@ -2,12 +2,23 @@ package protocol
import json "github.com/goccy/go-json"
// PoolCapabilities advertises client connection pool capabilities
type PoolCapabilities struct {
MaxDataConns int `json:"max_data_conns"` // Maximum data connections client supports
Version int `json:"version"` // Protocol version for pool features
}
// RegisterRequest is sent by client to register a tunnel
type RegisterRequest struct {
Token string `json:"token"` // Authentication token
CustomSubdomain string `json:"custom_subdomain"` // Optional custom subdomain
TunnelType TunnelType `json:"tunnel_type"` // http, tcp, udp
LocalPort int `json:"local_port"` // Local port to forward to
// Connection pool fields (optional, for multi-connection support)
ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy
TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
}
// RegisterResponse is sent by server after successful registration
@@ -16,6 +27,25 @@ type RegisterResponse struct {
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
URL string `json:"url"` // Full tunnel URL
Message string `json:"message"` // Success message
// Connection pool fields (optional, for multi-connection support)
TunnelID string `json:"tunnel_id,omitempty"` // Unique tunnel identifier
SupportsDataConn bool `json:"supports_data_conn,omitempty"` // Server supports multi-connection
RecommendedConns int `json:"recommended_conns,omitempty"` // Suggested data connection count
}
// DataConnectRequest is sent by data connections to join a tunnel
type DataConnectRequest struct {
TunnelID string `json:"tunnel_id"` // Tunnel to join
Token string `json:"token"` // Same auth token as primary
ConnectionID string `json:"connection_id"` // Unique connection identifier
}
// DataConnectResponse acknowledges data connection
type DataConnectResponse struct {
Accepted bool `json:"accepted"` // Whether connection was accepted
ConnectionID string `json:"connection_id"` // Echoed connection ID
Message string `json:"message,omitempty"` // Optional message
}
// ErrorMessage represents an error
@@ -24,9 +54,6 @@ type ErrorMessage struct {
Message string `json:"message"` // Error message
}
// Note: DataHeader is now defined in binary_header.go as a pure binary structure
// TCPData has been removed - use DataHeader + raw bytes directly
// Marshal helpers for control plane messages (JSON encoding)
func MarshalJSON(v interface{}) ([]byte, error) {
return json.Marshal(v)

View File

@@ -1,96 +0,0 @@
package protocol
import (
"encoding/binary"
"errors"
"drip/internal/shared/pool"
)
// encodeDataPayload serializes a data header plus its payload into a single
// contiguous frame payload using a plain heap allocation. The error result
// is always nil; it exists for signature symmetry with the pooled variant.
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
	total := binaryHeaderMinSize + len(header.StreamID) + len(header.RequestID) + len(data)
	buf := make([]byte, 0, total)
	flags := byte(header.Type) & 0x07
	if header.IsLast {
		flags |= 0x08
	}
	buf = append(buf, flags)
	buf = binary.BigEndian.AppendUint16(buf, uint16(len(header.StreamID)))
	buf = binary.BigEndian.AppendUint16(buf, uint16(len(header.RequestID)))
	buf = append(buf, header.StreamID...)
	buf = append(buf, header.RequestID...)
	buf = append(buf, data...)
	return buf, nil
}
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
// Payloads below the adaptive threshold, or above pool.SizeLarge, use a
// plain allocation; mid-sized payloads borrow a pooled buffer.
// Returns the payload slice and the pool buffer pointer (nil when no pooled
// buffer was used); the caller is responsible for arranging its release.
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
	streamIDLen := len(header.StreamID)
	requestIDLen := len(header.RequestID)
	totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
	// Threshold comes from the global adaptive pool manager.
	dynamicThreshold := GetAdaptiveThreshold()
	if totalLen < dynamicThreshold {
		regularPayload, err := encodeDataPayload(header, data)
		return regularPayload, nil, err
	}
	if totalLen > pool.SizeLarge {
		// Oversized payloads bypass the pool.
		regularPayload, err := encodeDataPayload(header, data)
		return regularPayload, nil, err
	}
	poolBuffer = pool.GetBuffer(totalLen)
	payload = (*poolBuffer)[:totalLen]
	// Flags: bits 0-2 = Type, bit 3 = IsLast (same layout as DataHeader).
	flags := uint8(header.Type) & 0x07
	if header.IsLast {
		flags |= 0x08
	}
	payload[0] = flags
	binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
	binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
	offset := binaryHeaderMinSize
	copy(payload[offset:], header.StreamID)
	offset += streamIDLen
	copy(payload[offset:], header.RequestID)
	offset += requestIDLen
	copy(payload[offset:], data)
	return payload, poolBuffer, nil
}
// DecodeDataPayload splits a frame payload into its DataHeader and the raw
// data that follows it. The returned data slice aliases payload.
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
	if len(payload) < binaryHeaderMinSize {
		return DataHeader{}, nil, errors.New("invalid payload: too short")
	}
	var hdr DataHeader
	if err := hdr.UnmarshalBinary(payload); err != nil {
		return DataHeader{}, nil, err
	}
	n := hdr.Size()
	if len(payload) < n {
		return DataHeader{}, nil, errors.New("invalid payload: data missing")
	}
	return hdr, payload[n:], nil
}

View File

@@ -4,20 +4,11 @@ import (
"sync"
)
// SafeFrame wraps Frame with automatic resource cleanup
type SafeFrame struct {
*Frame
once sync.Once
}
// NewSafeFrame creates a SafeFrame that implements io.Closer
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
return &SafeFrame{
Frame: NewFrame(frameType, payload),
}
}
// Close implements io.Closer, ensures Release is called exactly once
func (sf *SafeFrame) Close() error {
sf.once.Do(func() {
if sf.Frame != nil {
@@ -27,14 +18,6 @@ func (sf *SafeFrame) Close() error {
return nil
}
// WithFrame wraps an existing Frame with automatic cleanup
func WithFrame(frame *Frame) *SafeFrame {
return &SafeFrame{Frame: frame}
}
// MustClose is a helper that calls Close and panics on error (for defer cleanup)
func (sf *SafeFrame) MustClose() {
if err := sf.Close(); err != nil {
panic(err)
}
}

View File

@@ -4,16 +4,18 @@ import (
"errors"
"io"
"sync"
"sync/atomic"
"time"
)
type FrameWriter struct {
conn io.Writer
queue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
conn io.Writer
queue chan *Frame
controlQueue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
maxBatch int
maxBatchWait time.Duration
@@ -24,13 +26,20 @@ type FrameWriter struct {
heartbeatControl chan struct{}
// Error handling
writeErr error
errOnce sync.Once
onWriteError func(error) // Callback for write errors
writeErr error
errOnce sync.Once
onWriteError func(error) // Callback for write errors
// Adaptive flushing
adaptiveFlush bool // Enable adaptive flush based on queue depth
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
adaptiveFlush bool // Enable adaptive flush based on queue depth
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
// Hooks
preWriteHook func(*Frame) // Called right before a frame is written to conn
// Backlog tracking
queuedFrames atomic.Int64
queuedBytes atomic.Int64
}
func NewFrameWriter(conn io.Writer) *FrameWriter {
@@ -41,8 +50,14 @@ func NewFrameWriter(conn io.Writer) *FrameWriter {
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
w := &FrameWriter{
conn: conn,
queue: make(chan *Frame, queueSize),
conn: conn,
queue: make(chan *Frame, queueSize),
controlQueue: make(chan *Frame, func() int {
if queueSize < 256 {
return queueSize
}
return 256
}()), // control path needs small, fast buffer
batch: make([]*Frame, 0, maxBatch),
maxBatch: maxBatch,
maxBatchWait: maxBatchWait,
@@ -74,6 +89,22 @@ func (w *FrameWriter) writeLoop() {
}()
for {
// Always drain control queue first to prioritize control/heartbeat frames.
select {
case frame, ok := <-w.controlQueue:
if !ok {
w.mu.Lock()
w.flushBatchLocked()
w.mu.Unlock()
return
}
w.mu.Lock()
w.flushFrameLocked(frame)
w.mu.Unlock()
continue
default:
}
select {
case frame, ok := <-w.queue:
if !ok {
@@ -105,8 +136,7 @@ func (w *FrameWriter) writeLoop() {
w.mu.Lock()
if w.heartbeatCallback != nil {
if frame := w.heartbeatCallback(); frame != nil {
w.batch = append(w.batch, frame)
w.flushBatchLocked()
w.flushFrameLocked(frame)
}
}
w.mu.Unlock()
@@ -139,22 +169,47 @@ func (w *FrameWriter) flushBatchLocked() {
}
for _, frame := range w.batch {
if err := WriteFrame(w.conn, frame); err != nil {
w.errOnce.Do(func() {
w.writeErr = err
if w.onWriteError != nil {
go w.onWriteError(err)
}
w.closed = true
})
}
frame.Release()
w.flushFrameLocked(frame)
}
w.batch = w.batch[:0]
}
// flushFrameLocked writes a single frame immediately. Caller must hold w.mu.
func (w *FrameWriter) flushFrameLocked(frame *Frame) {
if frame == nil {
return
}
if w.preWriteHook != nil {
w.preWriteHook(frame)
}
if err := WriteFrame(w.conn, frame); err != nil {
w.errOnce.Do(func() {
w.writeErr = err
if w.onWriteError != nil {
go w.onWriteError(err)
}
w.closed = true
})
}
w.unmarkQueued(frame)
frame.Release()
}
func (w *FrameWriter) WriteFrame(frame *Frame) error {
return w.WriteFrameWithCancel(frame, nil)
}
// WriteFrameWithCancel writes a frame with an optional cancellation channel
// If cancel is closed, the write will be aborted immediately
func (w *FrameWriter) WriteFrameWithCancel(frame *Frame, cancel <-chan struct{}) error {
if frame == nil {
return nil
}
w.mu.Lock()
if w.closed {
w.mu.Unlock()
@@ -165,10 +220,19 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
}
w.mu.Unlock()
size := int64(len(frame.Payload) + FrameHeaderSize)
w.queuedFrames.Add(1)
w.queuedBytes.Add(size)
atomic.StoreInt64(&frame.queuedBytes, size)
// Try non-blocking first for best performance
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
@@ -176,6 +240,54 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
return err
}
return errors.New("writer closed")
default:
}
// Queue full - block with cancellation support
if cancel != nil {
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-cancel:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("write cancelled")
}
}
// No cancel channel - block with timeout
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-time.After(30 * time.Second):
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("write queue full timeout")
}
}
@@ -189,8 +301,14 @@ func (w *FrameWriter) Close() error {
w.mu.Unlock()
close(w.queue)
close(w.controlQueue)
for frame := range w.queue {
w.unmarkQueued(frame)
frame.Release()
}
for frame := range w.controlQueue {
w.unmarkQueued(frame)
frame.Release()
}
@@ -264,3 +382,97 @@ func (w *FrameWriter) DisableAdaptiveFlush() {
w.adaptiveFlush = false
w.mu.Unlock()
}
// WriteControl enqueues a control/prioritized frame to be written ahead of data frames.
func (w *FrameWriter) WriteControl(frame *Frame) error {
if frame == nil {
return nil
}
w.mu.Lock()
if w.closed {
w.mu.Unlock()
if w.writeErr != nil {
return w.writeErr
}
return errors.New("writer closed")
}
w.mu.Unlock()
size := int64(len(frame.Payload) + FrameHeaderSize)
w.queuedFrames.Add(1)
w.queuedBytes.Add(size)
atomic.StoreInt64(&frame.queuedBytes, size)
// Try non-blocking first
select {
case w.controlQueue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
default:
}
// Queue full - wait with timeout
select {
case w.controlQueue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-time.After(50 * time.Millisecond):
// Control frames should have priority, shorter timeout
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("control queue full timeout")
}
}
// SetPreWriteHook registers a callback invoked just before a frame is written to the underlying writer.
func (w *FrameWriter) SetPreWriteHook(hook func(*Frame)) {
w.mu.Lock()
w.preWriteHook = hook
w.mu.Unlock()
}
// QueuedFrames returns the number of frames currently queued (data + control).
func (w *FrameWriter) QueuedFrames() int64 {
return w.queuedFrames.Load()
}
// QueuedBytes returns the approximate number of bytes currently queued.
func (w *FrameWriter) QueuedBytes() int64 {
return w.queuedBytes.Load()
}
// unmarkQueued decrements backlog counters for a frame once it is written or discarded.
func (w *FrameWriter) unmarkQueued(frame *Frame) {
if frame == nil {
return
}
size := atomic.SwapInt64(&frame.queuedBytes, 0)
if size <= 0 {
return
}
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
}

View File

@@ -0,0 +1,77 @@
package stats
// FormatBytes formats a byte count as a human readable string (B/KB/MB/GB,
// 1024-based). Fractions are truncated, not rounded, preserving the
// historical output format.
func FormatBytes(bytes int64) string {
	const (
		KB = 1024
		MB = KB * 1024
		GB = MB * 1024
	)
	switch {
	case bytes >= GB:
		return formatFloat(float64(bytes)/float64(GB)) + " GB"
	case bytes >= MB:
		return formatFloat(float64(bytes)/float64(MB)) + " MB"
	case bytes >= KB:
		return formatFloat(float64(bytes)/float64(KB)) + " KB"
	default:
		return formatInt(bytes) + " B"
	}
}

// FormatSpeed formats a rate (bytes per second) as a human readable string.
func FormatSpeed(bytesPerSec int64) string {
	if bytesPerSec == 0 {
		return "0 B/s"
	}
	return FormatBytes(bytesPerSec) + "/s"
}

// formatFloat picks precision by magnitude: >=100 no decimals, >=10 one
// decimal, otherwise two decimals.
func formatFloat(f float64) string {
	if f >= 100 {
		return formatInt(int64(f))
	} else if f >= 10 {
		return formatOneDecimal(f)
	}
	return formatTwoDecimal(f)
}

// formatInt renders an int64 in base 10.
func formatInt(i int64) string {
	return intToStr(i)
}

// formatOneDecimal renders f (callers pass non-negative values) truncated to
// one decimal place.
func formatOneDecimal(f float64) string {
	i := int64(f * 10)
	whole := i / 10
	frac := i % 10
	return intToStr(whole) + "." + intToStr(frac)
}

// formatTwoDecimal renders f (callers pass non-negative values) truncated to
// two decimal places, zero-padding the fraction.
func formatTwoDecimal(f float64) string {
	i := int64(f * 100)
	whole := i / 100
	frac := i % 100
	if frac < 10 {
		return intToStr(whole) + ".0" + intToStr(frac)
	}
	return intToStr(whole) + "." + intToStr(frac)
}

// intToStr renders an int64 in base 10 without importing strconv.
//
// Fix: digits are produced in negative space so that math.MinInt64 — whose
// absolute value does not fit in an int64 — is handled correctly. The
// previous implementation recursed on -i, which overflows for MinInt64 and
// recursed forever.
func intToStr(i int64) string {
	if i == 0 {
		return "0"
	}
	n := i
	neg := n < 0
	if !neg {
		n = -n
	}
	var buf [20]byte
	pos := len(buf)
	for n < 0 {
		pos--
		buf[pos] = byte('0' - n%10) // n%10 is in [-9, 0] for negative n
		n /= 10
	}
	if neg {
		return "-" + string(buf[pos:])
	}
	return string(buf[pos:])
}

View File

@@ -0,0 +1,184 @@
package stats
import (
"sync"
"sync/atomic"
"time"
)
// TrafficStats tracks traffic statistics for a tunnel connection.
//
// Byte/request counters use typed atomics (atomic.Int64, consistent with the
// rest of the codebase, and with guaranteed 64-bit alignment on 32-bit
// platforms); the speed-calculation state is guarded by speedMu.
type TrafficStats struct {
	// Total bytes transferred since startTime.
	totalBytesIn  atomic.Int64
	totalBytesOut atomic.Int64
	// Request / connection counters.
	totalRequests     atomic.Int64
	activeConnections atomic.Int64
	// Baseline for speed calculation; guarded by speedMu.
	lastBytesIn  int64
	lastBytesOut int64
	lastTime     time.Time
	speedMu      sync.Mutex
	// Current speed (bytes per second); guarded by speedMu.
	speedIn  int64
	speedOut int64
	// Start time, used for uptime reporting.
	startTime time.Time
}

// NewTrafficStats creates a new traffic stats tracker.
func NewTrafficStats() *TrafficStats {
	now := time.Now()
	return &TrafficStats{
		startTime: now,
		lastTime:  now,
	}
}

// AddBytesIn adds incoming bytes to the counter.
func (s *TrafficStats) AddBytesIn(n int64) {
	s.totalBytesIn.Add(n)
}

// AddBytesOut adds outgoing bytes to the counter.
func (s *TrafficStats) AddBytesOut(n int64) {
	s.totalBytesOut.Add(n)
}

// AddRequest increments the request counter.
func (s *TrafficStats) AddRequest() {
	s.totalRequests.Add(1)
}

// IncActiveConnections increments the active connection counter.
func (s *TrafficStats) IncActiveConnections() {
	s.activeConnections.Add(1)
}

// DecActiveConnections decrements the active connection counter, clamping at
// zero so mismatched Inc/Dec pairs cannot drive it negative.
func (s *TrafficStats) DecActiveConnections() {
	if s.activeConnections.Add(-1) < 0 {
		s.activeConnections.Store(0)
	}
}

// GetTotalBytesIn returns total incoming bytes.
func (s *TrafficStats) GetTotalBytesIn() int64 {
	return s.totalBytesIn.Load()
}

// GetTotalBytesOut returns total outgoing bytes.
func (s *TrafficStats) GetTotalBytesOut() int64 {
	return s.totalBytesOut.Load()
}

// GetTotalRequests returns total request count.
func (s *TrafficStats) GetTotalRequests() int64 {
	return s.totalRequests.Load()
}

// GetActiveConnections returns the current active connection count.
func (s *TrafficStats) GetActiveConnections() int64 {
	return s.activeConnections.Load()
}

// GetTotalBytes returns total bytes (in + out).
func (s *TrafficStats) GetTotalBytes() int64 {
	return s.GetTotalBytesIn() + s.GetTotalBytesOut()
}

// UpdateSpeed recalculates the current transfer speed from the byte deltas
// since the previous call. Call it periodically (e.g. every second); calls
// less than 100ms apart are ignored to avoid noisy division.
func (s *TrafficStats) UpdateSpeed() {
	s.speedMu.Lock()
	defer s.speedMu.Unlock()
	now := time.Now()
	elapsed := now.Sub(s.lastTime).Seconds()
	if elapsed < 0.1 {
		return
	}
	currentIn := s.totalBytesIn.Load()
	currentOut := s.totalBytesOut.Load()
	deltaIn := currentIn - s.lastBytesIn
	deltaOut := currentOut - s.lastBytesOut
	// Instantaneous speed; drops straight to zero when no new bytes moved.
	s.speedIn = 0
	if deltaIn > 0 {
		s.speedIn = int64(float64(deltaIn) / elapsed)
	}
	s.speedOut = 0
	if deltaOut > 0 {
		s.speedOut = int64(float64(deltaOut) / elapsed)
	}
	s.lastBytesIn = currentIn
	s.lastBytesOut = currentOut
	s.lastTime = now
}

// GetSpeedIn returns current incoming speed in bytes per second.
func (s *TrafficStats) GetSpeedIn() int64 {
	s.speedMu.Lock()
	defer s.speedMu.Unlock()
	return s.speedIn
}

// GetSpeedOut returns current outgoing speed in bytes per second.
func (s *TrafficStats) GetSpeedOut() int64 {
	s.speedMu.Lock()
	defer s.speedMu.Unlock()
	return s.speedOut
}

// GetUptime returns how long the connection has been active.
func (s *TrafficStats) GetUptime() time.Duration {
	return time.Since(s.startTime)
}

// Snapshot is a point-in-time copy of all stats.
type Snapshot struct {
	TotalBytesIn      int64
	TotalBytesOut     int64
	TotalBytes        int64
	TotalRequests     int64
	ActiveConnections int64
	SpeedIn           int64 // bytes per second
	SpeedOut          int64 // bytes per second
	Uptime            time.Duration
}

// GetSnapshot returns a snapshot of current stats.
func (s *TrafficStats) GetSnapshot() Snapshot {
	s.speedMu.Lock()
	speedIn := s.speedIn
	speedOut := s.speedOut
	s.speedMu.Unlock()
	totalIn := s.totalBytesIn.Load()
	totalOut := s.totalBytesOut.Load()
	return Snapshot{
		TotalBytesIn:      totalIn,
		TotalBytesOut:     totalOut,
		TotalBytes:        totalIn + totalOut,
		TotalRequests:     s.totalRequests.Load(),
		ActiveConnections: s.activeConnections.Load(),
		SpeedIn:           speedIn,
		SpeedOut:          speedOut,
		Uptime:            time.Since(s.startTime),
	}
}

View File

@@ -0,0 +1,117 @@
package ui
import (
"fmt"
)
// RenderConfigInit renders the banner shown at the start of config setup.
func RenderConfigInit() string {
	title := "Drip Configuration Setup"
	box := boxStyle.Width(50)
	return "\n" + box.Render(titleStyle.Render(title)) + "\n"
}
// RenderConfigShow renders the current configuration for display.
// When tokenHidden is set, the token is masked so it is safe to print.
func RenderConfigShow(server, token string, tokenHidden bool, tlsEnabled bool, configPath string) string {
	lines := []string{
		KeyValue("Server", server),
	}
	switch {
	case token == "":
		lines = append(lines, KeyValue("Token", Muted("(not set)")))
	case tokenHidden:
		lines = append(lines, KeyValue("Token", Muted(maskToken(token)+" (hidden)")))
	default:
		lines = append(lines, KeyValue("Token", token))
	}
	tlsStatus := "enabled"
	if !tlsEnabled {
		tlsStatus = "disabled"
	}
	lines = append(lines, KeyValue("TLS", tlsStatus))
	lines = append(lines, KeyValue("Config", Muted(configPath)))
	return Info("Current Configuration", lines...)
}

// maskToken obscures a token for display: long tokens keep their first and
// last three characters, shorter ones only the first three.
// Fix: the previous inline version sliced token[:3] unconditionally and
// panicked (index out of range) on tokens shorter than three characters.
func maskToken(token string) string {
	switch {
	case len(token) > 10:
		return token[:3] + "***" + token[len(token)-3:]
	case len(token) >= 3:
		return token[:3] + "***"
	default:
		return "***"
	}
}
// RenderConfigSaved renders the confirmation shown after the config file is
// written to configPath.
func RenderConfigSaved(configPath string) string {
	return SuccessBox(
		"Configuration Saved",
		Muted("Config saved to: ")+configPath,
		"",
		Muted("You can now use 'drip' without --server and --token flags"),
	)
}
// RenderConfigUpdated renders a success box listing each applied update
// followed by a closing note.
func RenderConfigUpdated(updates []string) string {
	lines := make([]string, 0, len(updates)+2)
	for _, update := range updates {
		lines = append(lines, Success(update))
	}
	lines = append(lines, "", Muted("Configuration has been updated"))
	return SuccessBox("Configuration Updated", lines...)
}
// RenderConfigDeleted renders the confirmation shown after the config file
// is removed.
func RenderConfigDeleted() string {
	return SuccessBox("Configuration Deleted", Muted("Configuration file has been removed"))
}
// RenderConfigValidation renders validation results for the server, token,
// and TLS settings; a success box is used only when all three checks pass.
func RenderConfigValidation(serverValid bool, serverMsg string, tokenSet bool, tokenMsg string, tlsEnabled bool) string {
	var lines []string
	if serverValid {
		lines = append(lines, Success(serverMsg))
	} else {
		lines = append(lines, Error(serverMsg))
	}
	if tokenSet {
		lines = append(lines, Success(tokenMsg))
	} else {
		lines = append(lines, Warning(tokenMsg))
	}
	if tlsEnabled {
		lines = append(lines, Success("TLS is enabled"))
	} else {
		lines = append(lines, Warning("TLS is disabled (not recommended for production)"))
	}
	lines = append(lines, "", Muted("Configuration validation complete"))
	if serverValid && tokenSet && tlsEnabled {
		return SuccessBox("Configuration Valid", lines...)
	}
	return WarningBox("Configuration Validation", lines...)
}
// RenderDaemonStarted renders the success box shown after a tunnel is forked
// into the background: tunnel type/port/PID, follow-up commands, and the log
// file path.
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
	lines := []string{
		KeyValue("Type", Highlight(tunnelType)),
		KeyValue("Port", fmt.Sprintf("%d", port)),
		KeyValue("PID", fmt.Sprintf("%d", pid)),
		"",
		Muted("Commands:"),
		Cyan(" drip list") + Muted(" Check tunnel status"),
		Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
		Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
		"",
		Muted("Logs: ") + mutedStyle.Render(logPath),
	}
	return SuccessBox("Tunnel Started in Background", lines...)
}

View File

@@ -0,0 +1,184 @@
package ui
import (
"github.com/charmbracelet/lipgloss"
)
var (
	// Colors inspired by Vercel CLI: blue accent for success/highlight,
	// amber for warnings, red for errors, grey for muted text.
	successColor   = lipgloss.Color("#0070F3")
	warningColor   = lipgloss.Color("#F5A623")
	errorColor     = lipgloss.Color("#E00")
	mutedColor     = lipgloss.Color("#888")
	highlightColor = lipgloss.Color("#0070F3")
	cyanColor      = lipgloss.Color("#50E3C2")
	// Box styles - Vercel-like clean box. Lipgloss styles are values, so
	// each BorderForeground call below derives an independent copy of boxStyle.
	boxStyle = lipgloss.NewStyle().
			Border(lipgloss.RoundedBorder()).
			Padding(1, 2).
			MarginTop(1).
			MarginBottom(1)
	successBoxStyle = boxStyle.BorderForeground(successColor)
	warningBoxStyle = boxStyle.BorderForeground(warningColor)
	errorBoxStyle   = boxStyle.BorderForeground(errorColor)
	// Text styles used by the helper functions below (Success, Muted, ...).
	titleStyle = lipgloss.NewStyle().
			Bold(true)
	subtitleStyle = lipgloss.NewStyle().
			Foreground(mutedColor)
	successStyle = lipgloss.NewStyle().
			Foreground(successColor).
			Bold(true)
	errorStyle = lipgloss.NewStyle().
			Foreground(errorColor).
			Bold(true)
	warningStyle = lipgloss.NewStyle().
			Foreground(warningColor).
			Bold(true)
	mutedStyle = lipgloss.NewStyle().
			Foreground(mutedColor)
	highlightStyle = lipgloss.NewStyle().
			Foreground(highlightColor).
			Bold(true)
	cyanStyle = lipgloss.NewStyle().
			Foreground(cyanColor)
	urlStyle = lipgloss.NewStyle().
			Foreground(highlightColor).
			Underline(true).
			Bold(true)
	// labelStyle has a fixed width so KeyValue output lines up in columns.
	labelStyle = lipgloss.NewStyle().
			Foreground(mutedColor).
			Width(12)
	valueStyle = lipgloss.NewStyle().
			Bold(true)
	// Table styles (padding handled manually for consistent Windows output)
	tableHeaderStyle = lipgloss.NewStyle().
				Foreground(mutedColor).
				Bold(true)
)
// Success returns text rendered as a success message with a check mark.
func Success(text string) string {
	msg := "✓ " + text
	return successStyle.Render(msg)
}

// Error returns text rendered as an error message with a cross mark.
func Error(text string) string {
	msg := "✗ " + text
	return errorStyle.Render(msg)
}

// Warning returns text rendered as a warning message with a warning sign.
func Warning(text string) string {
	msg := "⚠ " + text
	return warningStyle.Render(msg)
}

// Muted returns text rendered in the muted (dim) style.
func Muted(text string) string {
	return mutedStyle.Render(text)
}

// Highlight returns text rendered in the bold accent style.
func Highlight(text string) string {
	return highlightStyle.Render(text)
}

// Cyan returns text rendered in the cyan accent style.
func Cyan(text string) string {
	return cyanStyle.Render(text)
}

// URL returns text rendered as an underlined, bold URL.
func URL(text string) string {
	return urlStyle.Render(text)
}

// Title returns text rendered in the bold title style.
func Title(text string) string {
	return titleStyle.Render(text)
}

// Subtitle returns text rendered in the muted subtitle style.
func Subtitle(text string) string {
	return subtitleStyle.Render(text)
}

// KeyValue returns a fixed-width muted "key:" label followed by a bold value.
func KeyValue(key, value string) string {
	label := labelStyle.Render(key + ":")
	return label + " " + valueStyle.Render(value)
}
// Info renders an info box (Vercel-style)
func Info(title string, lines ...string) string {
content := titleStyle.Render(title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return boxStyle.Render(content)
}
// SuccessBox renders a success box
func SuccessBox(title string, lines ...string) string {
content := successStyle.Render("✓ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return successBoxStyle.Render(content)
}
// WarningBox renders a warning box
func WarningBox(title string, lines ...string) string {
content := warningStyle.Render("⚠ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return warningBoxStyle.Render(content)
}
// ErrorBox renders an error box
func ErrorBox(title string, lines ...string) string {
content := errorStyle.Render("✗ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return errorBoxStyle.Render(content)
}

145
internal/shared/ui/table.go Normal file
View File

@@ -0,0 +1,145 @@
package ui
import (
"fmt"
"runtime"
"strings"
"github.com/charmbracelet/lipgloss"
)
// Table represents a simple table for CLI output
type Table struct {
	headers []string   // column headers; their count fixes the number of rendered columns
	rows    [][]string // data rows; cells beyond len(headers) are ignored when rendering
	title   string     // optional title rendered above the table
}
// NewTable creates a new table with the given column headers and no rows.
func NewTable(headers []string) *Table {
	t := &Table{headers: headers}
	t.rows = [][]string{}
	return t
}
// WithTitle sets the table title and returns the table for call chaining.
func (t *Table) WithTitle(title string) *Table {
	t.title = title
	return t
}
// AddRow appends one data row and returns the table for call chaining.
func (t *Table) AddRow(row []string) *Table {
	t.rows = append(t.rows, row)
	return t
}
// Render renders the table as aligned text (Vercel-style): optional title,
// styled header row, separator line, then the data rows. Returns "" when
// there are no rows. Widths are measured with lipgloss so ANSI escape
// sequences do not affect alignment.
func (t *Table) Render() string {
	if len(t.rows) == 0 {
		return ""
	}
	// Column width = max visible width over the header and every cell.
	widths := make([]int, len(t.headers))
	for i, h := range t.headers {
		widths[i] = lipgloss.Width(h)
	}
	for _, row := range t.rows {
		for i, cell := range row {
			if i >= len(widths) {
				continue
			}
			if w := lipgloss.Width(cell); w > widths[i] {
				widths[i] = w
			}
		}
	}
	var b strings.Builder
	// Optional title.
	if t.title != "" {
		b.WriteString("\n")
		b.WriteString(titleStyle.Render(t.title))
		b.WriteString("\n\n")
	}
	// Header row.
	cells := make([]string, len(t.headers))
	for i, h := range t.headers {
		cells[i] = padRight(tableHeaderStyle.Render(h), widths[i])
	}
	b.WriteString(strings.Join(cells, " "))
	b.WriteString("\n")
	// Separator line (ASCII dash on Windows consoles).
	sep := "─"
	if runtime.GOOS == "windows" {
		sep = "-"
	}
	for i := range t.headers {
		cells[i] = mutedStyle.Render(strings.Repeat(sep, widths[i]))
	}
	b.WriteString(strings.Join(cells, " "))
	b.WriteString("\n")
	// Data rows; missing trailing cells render as empty columns.
	for _, row := range t.rows {
		line := make([]string, len(t.headers))
		for i, cell := range row {
			if i < len(widths) {
				line[i] = padRight(cell, widths[i])
			}
		}
		b.WriteString(strings.Join(line, " "))
		b.WriteString("\n")
	}
	b.WriteString("\n")
	return b.String()
}
// padRight pads text with trailing spaces up to targetWidth visible
// columns, measuring with lipgloss so ANSI escape codes are not counted.
// Text already at or beyond the target width is returned unchanged.
func padRight(text string, targetWidth int) string {
	gap := targetWidth - lipgloss.Width(text)
	if gap <= 0 {
		return text
	}
	return text + strings.Repeat(" ", gap)
}
// Print writes the rendered table to stdout.
func (t *Table) Print() {
	fmt.Print(t.Render())
}
// RenderList renders items as a bulleted list, one item per line.
// Uses an ASCII asterisk instead of "•" on Windows consoles.
func RenderList(items []string) string {
	marker := "•"
	if runtime.GOOS == "windows" {
		marker = "*"
	}
	var b strings.Builder
	for _, item := range items {
		b.WriteString(mutedStyle.Render(" " + marker + " "))
		b.WriteString(item)
		b.WriteString("\n")
	}
	return b.String()
}
// RenderNumberedList renders items as a 1-based numbered list,
// one item per line.
func RenderNumberedList(items []string) string {
	var b strings.Builder
	for i, item := range items {
		b.WriteString(mutedStyle.Render(fmt.Sprintf(" %d. ", i+1)))
		b.WriteString(item)
		b.WriteString("\n")
	}
	return b.String()
}

View File

@@ -0,0 +1,251 @@
package ui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/lipgloss"
)
const (
	tunnelCardWidth  = 76 // total width of the connected/stats cards
	statsColumnWidth = 32 // width of each column in the live-metrics grid
)
// Latency threshold colors used by formatLatency (fast → slow).
var (
	latencyFastColor   = lipgloss.Color("#22c55e") // green
	latencyYellowColor = lipgloss.Color("#eab308") // yellow
	latencyOrangeColor = lipgloss.Color("#f97316") // orange
	latencyRedColor    = lipgloss.Color("#ef4444") // red
)
// TunnelStatus represents the status of a tunnel
type TunnelStatus struct {
	Type         string        // "http", "https", "tcp"
	URL          string        // Public URL
	LocalAddr    string        // Local address
	Latency      time.Duration // Current latency
	BytesIn      int64         // Bytes received
	BytesOut     int64         // Bytes sent
	SpeedIn      float64       // Download speed (formatted as B/s by formatSpeed)
	SpeedOut     float64       // Upload speed (formatted as B/s by formatSpeed)
	TotalRequest int64         // Total requests (connections for TCP tunnels)
}
// RenderTunnelConnected renders the card shown once a tunnel is established:
// an icon + type badge headline, the public URL, the forwarded local
// address, and a usage hint, all inside a rounded border in the tunnel
// type's accent color.
func RenderTunnelConnected(status *TunnelStatus) string {
	icon, typeName, accent := tunnelVisuals(status.Type)

	badge := lipgloss.NewStyle().
		Background(accent).
		Foreground(lipgloss.Color("#f8fafc")).
		Bold(true).
		Padding(0, 1).
		Render(strings.ToUpper(typeName) + " TUNNEL")

	headline := lipgloss.JoinHorizontal(
		lipgloss.Left,
		lipgloss.NewStyle().Foreground(accent).Render(icon),
		lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Tunnel Connected"),
		lipgloss.NewStyle().MarginLeft(2).Render(badge),
	)

	publicLine := lipgloss.JoinHorizontal(
		lipgloss.Left,
		urlStyle.Foreground(accent).Render(status.URL),
		lipgloss.NewStyle().MarginLeft(1).Foreground(mutedColor).Render("(forwarded address)"),
	)

	localLine := lipgloss.NewStyle().
		MarginLeft(2).
		Render(Muted("⇢ ") + valueStyle.Render(status.LocalAddr))

	footer := lipgloss.NewStyle().
		Foreground(latencyOrangeColor).
		Render("Ctrl+C to stop • reconnects automatically")

	frame := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(accent).
		Padding(1, 2).
		Width(tunnelCardWidth)

	body := lipgloss.JoinVertical(
		lipgloss.Left,
		headline,
		"",
		publicLine,
		localLine,
		"",
		footer,
	)
	return "\n" + frame.Render(body) + "\n"
}
// RenderTunnelStats renders the live-metrics card for a running tunnel:
// latency, request (or connection, for TCP) count, cumulative traffic,
// and current transfer speed, laid out in a 2x2 grid.
func RenderTunnelStats(status *TunnelStatus) string {
	_, _, accent := tunnelVisuals(status.Type)

	countLabel := "Requests"
	if status.Type == "tcp" {
		countLabel = "Connections"
	}

	latency := formatLatency(status.Latency)
	traffic := fmt.Sprintf("↓ %s ↑ %s", formatBytes(status.BytesIn), formatBytes(status.BytesOut))
	speed := fmt.Sprintf("↓ %s ↑ %s", formatSpeed(status.SpeedIn), formatSpeed(status.SpeedOut))
	count := fmt.Sprintf("%d", status.TotalRequest)

	title := lipgloss.JoinHorizontal(
		lipgloss.Left,
		lipgloss.NewStyle().Foreground(accent).Render("◉"),
		lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Live Metrics"),
	)
	topRow := lipgloss.JoinHorizontal(
		lipgloss.Top,
		statColumn("Latency", latency, statsColumnWidth),
		statColumn(countLabel, highlightStyle.Render(count), statsColumnWidth),
	)
	bottomRow := lipgloss.JoinHorizontal(
		lipgloss.Top,
		statColumn("Traffic", Cyan(traffic), statsColumnWidth),
		statColumn("Speed", warningStyle.Render(speed), statsColumnWidth),
	)

	frame := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(accent).
		Padding(1, 2).
		Width(tunnelCardWidth)

	body := lipgloss.JoinVertical(
		lipgloss.Left,
		title,
		"",
		topRow,
		bottomRow,
	)
	return "\n" + frame.Render(body) + "\n"
}
// RenderConnecting renders the connect progress line. attempt 0 means the
// initial connection; otherwise a reconnect attempt counter is shown.
func RenderConnecting(serverAddr string, attempt int, maxAttempts int) string {
	if attempt == 0 {
		return Highlight("◌") + " Connecting to " + Muted(serverAddr) + "..."
	}
	return Warning(fmt.Sprintf("◌ Reconnecting to %s (attempt %d/%d)...", serverAddr, attempt, maxAttempts))
}
// RenderConnectionFailed renders a connection failure message with the error.
func RenderConnectionFailed(err error) string {
	return Error(fmt.Sprintf("Connection failed: %v", err))
}
// RenderShuttingDown renders the shutdown notice.
func RenderShuttingDown() string {
	return Warning("⏹ Shutting down...")
}
// RenderConnectionLost renders the connection-lost notice.
func RenderConnectionLost() string {
	return Error("⚠ Connection lost!")
}
// RenderRetrying renders the retry-countdown notice for the given interval.
func RenderRetrying(interval time.Duration) string {
	return Muted(fmt.Sprintf(" Retrying in %v...", interval))
}
// formatLatency renders a latency value colored by threshold: green below
// 50ms, yellow below 150ms, orange below 300ms, red otherwise.
// Sub-millisecond values are shown in microseconds; a zero duration means
// no measurement has been taken yet.
func formatLatency(d time.Duration) string {
	if d == 0 {
		return mutedStyle.Render("measuring...")
	}
	ms := d.Milliseconds()
	color := latencyRedColor
	switch {
	case ms < 50:
		color = latencyFastColor
	case ms < 150:
		color = latencyYellowColor
	case ms < 300:
		color = latencyOrangeColor
	}
	style := lipgloss.NewStyle().Foreground(color)
	if ms == 0 {
		return style.Render(fmt.Sprintf("%dµs", d.Microseconds()))
	}
	return style.Render(fmt.Sprintf("%dms", ms))
}
// formatBytes formats a byte count as a human-readable string using
// 1024-based units, e.g. 512 -> "512 B", 1536 -> "1.5 KB".
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	// Count how many times the value divides by 1024 before dropping
	// below one unit; that index selects the K/M/G/T/P/E suffix.
	exp := -1
	for scaled := bytes; scaled >= unit; scaled /= unit {
		exp++
	}
	div := int64(unit)
	for i := 0; i < exp; i++ {
		div *= unit
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// formatSpeed formats a transfer rate as a human-readable string using
// 1024-based units, e.g. 100 -> "100 B/s", 2048 -> "2.0 KB/s".
func formatSpeed(bytesPerSec float64) string {
	const unit = 1024.0
	if bytesPerSec < unit {
		return fmt.Sprintf("%.0f B/s", bytesPerSec)
	}
	// Count divisions by 1024 to pick the suffix, then rebuild the divisor.
	idx := 0
	for v := bytesPerSec / unit; v >= unit; v /= unit {
		idx++
	}
	divisor := unit
	for i := 0; i < idx; i++ {
		divisor *= unit
	}
	return fmt.Sprintf("%.1f %cB/s", bytesPerSec/divisor, "KMGTPE"[idx])
}
// statColumn renders one metrics cell: an upper-cased muted label followed
// by its (already styled) value. When width is positive the cell is padded
// to that fixed width so columns line up across rows.
func statColumn(label, value string, width int) string {
	name := lipgloss.NewStyle().
		Foreground(mutedColor).
		Render(strings.ToUpper(label))
	cell := lipgloss.JoinHorizontal(
		lipgloss.Left,
		name,
		lipgloss.NewStyle().MarginLeft(1).Render(value),
	)
	if width > 0 {
		cell = lipgloss.NewStyle().Width(width).Render(cell)
	}
	return cell
}
// tunnelVisuals returns the icon, display name, and accent color for a
// tunnel type. Unknown types fall back to a generic globe icon, the
// upper-cased type string, and the default blue accent.
func tunnelVisuals(tunnelType string) (string, string, lipgloss.Color) {
	switch tunnelType {
	case "https":
		return "🔒", "HTTPS", lipgloss.Color("#2D8CFF")
	case "tcp":
		return "🔌", "TCP", lipgloss.Color("#50E3C2")
	case "http":
		return "🚀", "HTTP", lipgloss.Color("#0070F3")
	}
	return "🌐", strings.ToUpper(tunnelType), lipgloss.Color("#0070F3")
}