From cfb1b3fe31c37db79a67434ab620ddb0eca41faf Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 5 May 2026 18:40:42 +0200 Subject: [PATCH] [proxy] consolidate mapping update (#6072) --- management/internals/shared/grpc/proxy.go | 118 ++++++--- .../shared/grpc/proxy_snapshot_test.go | 174 ++++++++++++++ .../internals/shared/grpc/proxy_test.go | 3 + proxy/management_integration_test.go | 50 ++-- proxy/server.go | 45 +++- proxy/snapshot_reconcile_test.go | 227 ++++++++++++++++++ 6 files changed, 559 insertions(+), 58 deletions(-) create mode 100644 management/internals/shared/grpc/proxy_snapshot_test.go create mode 100644 proxy/snapshot_reconcile_test.go diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index d811a0f69..6763a3ba3 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -11,6 +11,8 @@ import ( "fmt" "net/http" "net/url" + "os" + "strconv" "strings" "sync" "time" @@ -82,11 +84,40 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + // tokenTTL is the lifetime of one-time tokens generated for proxy + // authentication. Defaults to defaultProxyTokenTTL when zero. + tokenTTL time.Duration + + // snapshotBatchSize is the number of mappings per gRPC message during + // initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE. + snapshotBatchSize int + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute +const defaultProxyTokenTTL = 5 * time.Minute + +const defaultSnapshotBatchSize = 500 + +func snapshotBatchSizeFromEnv() int { + if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return defaultSnapshotBatchSize +} + +// proxyTokenTTL returns the configured token TTL or the default when unset. 
+func (s *ProxyServiceServer) proxyTokenTTL() time.Duration { + if s.tokenTTL > 0 { + return s.tokenTTL + } + return defaultProxyTokenTTL +} + // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string @@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + snapshotBatchSize: snapshotBatchSizeFromEnv(), cancel: cancel, } go s.cleanupStaleProxies(ctx) @@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) - } - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { @@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) if err != nil { log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.CompareAndDelete(proxyID, conn) - if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) - } + cancel() return status.Errorf(codes.Internal, "register proxy in database: %v", err) } + s.connectedProxies.Store(proxyID, conn) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } + + if err := s.sendSnapshot(ctx, conn); err != nil { + if s.connectedProxies.CompareAndDelete(proxyID, conn) { + if unregErr 
:= s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr) + } + } + cancel() + if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) + } + return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) + } + + errChan := make(chan error, 2) + go s.sender(conn, errChan) + log.WithFields(log.Fields{ "proxy_id": proxyID, "session_id": sessionID, @@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() - if err := s.sendSnapshot(ctx, conn); err != nil { - return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) - } - - errChan := make(chan error, 2) - go s.sender(conn, errChan) - go s.heartbeat(connCtx, proxyRecord) select { @@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return err } + // Send mappings in batches to reduce per-message gRPC overhead while + // staying well within the default 4 MB message size limit. 
+ for i, step := 0, max(s.snapshotBatchSize, 1); i < len(mappings); i += step { // Clamp to 1: a zero or negative batch size would never advance the loop. + end := i + step + if end > len(mappings) { + end = len(mappings) + } + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: mappings[i:end], + InitialSyncComplete: end == len(mappings), + }); err != nil { + return fmt.Errorf("send snapshot batch: %w", err) + } + } + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { return fmt.Errorf("send snapshot completion: %w", err) } - return nil - } - - for i, m := range mappings { - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{m}, - InitialSyncComplete: i == len(mappings)-1, - }); err != nil { - return fmt.Errorf("send proxy mapping: %w", err) - } } return nil @@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) if err != nil { - log.WithFields(log.Fields{ - "service": service.Name, - "account": service.AccountID, - }).WithError(err).Error("failed to generate auth token for snapshot") - continue + return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) } m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) @@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes conn := value.(*proxyConnection) resp := s.perProxyMessage(update, conn.proxyID) if resp == nil { + log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() return true } select { case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: - log.Warnf("Failed to send service update to proxy server %s (channel full)", 
conn.proxyID) + log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() } return true }) @@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } msg := s.perProxyMessage(updateResponse, proxyID) if msg == nil { + log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() continue } select { case conn.sendChan <- msg: log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() } } } @@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. -// Returns nil if token generation fails (the proxy should be skipped). +// Returns nil if token generation fails; the caller must disconnect the +// proxy so it can resync via a fresh snapshot on reconnect. 
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) for _, mapping := range update.Mapping { @@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo continue } - token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL()) if err != nil { log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) return nil diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go new file mode 100644 index 000000000..e0c7425c5 --- /dev/null +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -0,0 +1,174 @@ +package grpc + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// recordingStream captures all messages sent via Send so tests can inspect +// batching behaviour without a real gRPC transport. 
+type recordingStream struct { + grpc.ServerStream + messages []*proto.GetMappingUpdateResponse +} + +func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error { + s.messages = append(s.messages, m) + return nil +} + +func (s *recordingStream) Context() context.Context { return context.Background() } +func (s *recordingStream) SetHeader(metadata.MD) error { return nil } +func (s *recordingStream) SendHeader(metadata.MD) error { return nil } +func (s *recordingStream) SetTrailer(metadata.MD) {} +func (s *recordingStream) SendMsg(any) error { return nil } +func (s *recordingStream) RecvMsg(any) error { return nil } + +// makeServices creates n enabled services assigned to the given cluster. +func makeServices(n int, cluster string) []*rpservice.Service { + services := make([]*rpservice.Service, n) + for i := range n { + services[i] = &rpservice.Service{ + ID: fmt.Sprintf("svc-%d", i), + AccountID: "acct-1", + Name: fmt.Sprintf("svc-%d", i), + Domain: fmt.Sprintf("svc-%d.example.com", i), + ProxyCluster: cluster, + Enabled: true, + Targets: []*rpservice.Target{ + {TargetType: rpservice.TargetTypeHost, TargetId: "host-1"}, + }, + } + } + return services +} + +func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer { + t.Helper() + s := &ProxyServiceServer{ + tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)), + snapshotBatchSize: batchSize, + } + s.SetProxyController(newTestProxyController()) + return s +} + +func TestSendSnapshot_BatchesMappings(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 7 // 3 + 3 + 1 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + stream: 
stream, + } + + err := s.sendSnapshot(context.Background(), conn) + require.NoError(t, err) + + // Expect ceil(7/3) = 3 messages + require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages") + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete") + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete") + + assert.Len(t, stream.messages[2].Mapping, 1) + assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete") + + // Verify all service IDs are present exactly once + seen := make(map[string]bool) + for _, msg := range stream.messages { + for _, m := range msg.Mapping { + assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id) + seen[m.Id] = true + } + } + assert.Len(t, seen, totalServices) +} + +func TestSendSnapshot_ExactBatchMultiple(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 6 // exactly 2 batches + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 2) + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete) + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.True(t, stream.messages[1].InitialSyncComplete) +} + +func TestSendSnapshot_SingleBatch(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 100 + const totalServices = 5 + + ctrl := gomock.NewController(t) + mgr := 
rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "all mappings should fit in one batch") + assert.Len(t, stream.messages[0].Mapping, totalServices) + assert.True(t, stream.messages[0].InitialSyncComplete) +} + +func TestSendSnapshot_EmptySnapshot(t *testing.T) { + const cluster = "cluster.example.com" + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil) + + s := newSnapshotTestServer(t, 500) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete") + assert.Empty(t, stream.messages[0].Mapping) + assert.True(t, stream.messages[0].InitialSyncComplete) +} diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index de4e96d93..5a7a457df 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. 
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) + ctx, cancel := context.WithCancel(context.Background()) conn := &proxyConnection{ proxyID: proxyID, address: clusterAddr, capabilities: caps, sendChan: ch, + ctx: ctx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index e9eae3210..99bbdad0c 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings from the snapshot - server sends each mapping individually mappingsByID := make(map[string]*proto.ProxyMapping) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) for _, m := range msg.GetMapping() { mappingsByID[m.GetId()] = m } + if msg.GetInitialSyncComplete() { + break + } } // Should receive 2 mappings total @@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually mappings := make([]*proto.ProxyMapping, 0) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) 
+ if msg.GetInitialSyncComplete() { + break + } } // Should receive the 2 mappings matching the cluster @@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) clusterAddress := "test.proxy.io" proxyID := "test-proxy-reconnect" - // Helper to receive all mappings from a stream - receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping { + receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { var mappings []*proto.ProxyMapping - for i := 0; i < count; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } } return mappings } @@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - firstMappings := receiveMappings(stream1, 2) + firstMappings := receiveMappings(stream1) cancel1() time.Sleep(100 * time.Millisecond) @@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - secondMappings := receiveMappings(stream2, 2) + secondMappings := receiveMappings(stream2) // Should receive the same mappings assert.Equal(t, len(firstMappings), len(secondMappings), @@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T } } - // Helper to receive and apply all mappings receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) { - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) applyMappings(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } } @@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually count := 0 - for i := 0; i < 2; i++ { + for { msg, 
err := stream.Recv() require.NoError(t, err) count += len(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } mu.Lock() @@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) }) require.NoError(t, err) - for i := 0; i < 2; i++ { - _, err := stream1.Recv() + for { + msg, err := stream1.Recv() require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } } require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, @@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) }) require.NoError(t, err) - for i := 0; i < 2; i++ { - _, err := stream2.Recv() + for { + msg, err := stream2.Recv() require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } } cancel1() diff --git a/proxy/server.go b/proxy/server.go index fbd0d058e..6980e1df1 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr operation := func() error { s.Logger.Debug("connecting to management mapping stream") + initialSyncDone = false + if s.healthChecker != nil { s.healthChecker.SetManagementConnected(false) } @@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr return ctx.Err() } + var snapshotIDs map[types.ServiceID]struct{} + if !*initialSyncDone { + snapshotIDs = make(map[types.ServiceID]struct{}) + } + for { // Check for context completion to gracefully shutdown. 
select { @@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr s.processMappings(ctx, msg.GetMapping()) s.Logger.Debug("Processing mapping update completed") - if !*initialSyncDone && msg.GetInitialSyncComplete() { - if s.healthChecker != nil { - s.healthChecker.SetInitialSyncComplete() + if !*initialSyncDone { + for _, m := range msg.GetMapping() { + snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} + } + if msg.GetInitialSyncComplete() { + s.reconcileSnapshot(ctx, snapshotIDs) + snapshotIDs = nil + if s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + } + *initialSyncDone = true + s.Logger.Info("Initial mapping sync complete") } - *initialSyncDone = true - s.Logger.Info("Initial mapping sync complete") } } } } +// reconcileSnapshot removes local mappings that are absent from the snapshot. +// This ensures services deleted while the proxy was disconnected get cleaned up. +func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) { + s.portMu.RLock() + var stale []*proto.ProxyMapping + for svcID, mapping := range s.lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, mapping) + } + } + s.portMu.RUnlock() + + for _, mapping := range stale { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "domain": mapping.GetDomain(), + }).Info("Removing stale mapping absent from snapshot") + s.removeMapping(ctx, mapping) + } +} + func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { for _, mapping := range mappings { s.Logger.WithFields(log.Fields{ diff --git a/proxy/snapshot_reconcile_test.go b/proxy/snapshot_reconcile_test.go new file mode 100644 index 000000000..042d8df77 --- /dev/null +++ b/proxy/snapshot_reconcile_test.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "context" + "io" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/health" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot +// so we can verify it without triggering removeMapping (which requires full +// server wiring). This keeps the test focused on the detection algorithm. +func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID { + var stale []types.ServiceID + for svcID := range lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, svcID) + } + } + return stale +} + +// TestStaleDetection_PartialOverlap verifies that only services absent from +// the snapshot are flagged as stale. +func TestStaleDetection_PartialOverlap(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + "svc-stale-a": {Id: "svc-stale-a"}, + "svc-stale-b": {Id: "svc-stale-b"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + "svc-3": {}, // new service, not in local + } + + stale := collectStaleIDs(local, snapshot) + assert.Len(t, stale, 2) + staleSet := make(map[types.ServiceID]struct{}) + for _, id := range stale { + staleSet[id] = struct{}{} + } + assert.Contains(t, staleSet, types.ServiceID("svc-stale-a")) + assert.Contains(t, staleSet, types.ServiceID("svc-stale-b")) +} + +// TestStaleDetection_AllStale verifies an empty snapshot flags everything. +func TestStaleDetection_AllStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + stale := collectStaleIDs(local, map[types.ServiceID]struct{}{}) + assert.Len(t, stale, 2) +} + +// TestStaleDetection_NoneStale verifies full overlap produces no stale entries. 
+func TestStaleDetection_NoneStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + stale := collectStaleIDs(local, snapshot) + assert.Empty(t, stale) +} + +// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty. +func TestStaleDetection_EmptyLocal(t *testing.T) { + stale := collectStaleIDs( + map[types.ServiceID]*proto.ProxyMapping{}, + map[types.ServiceID]struct{}{"svc-1": {}}, + ) + assert.Empty(t, stale) +} + +// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all +// local mappings are present in the snapshot (removeMapping is never called). +func TestReconcileSnapshot_NoStale(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"} + + snapshotIDs := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + // This should not panic — no stale entries means removeMapping is never called. + s.reconcileSnapshot(context.Background(), snapshotIDs) + + assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot") +} + +// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with +// no local mappings. 
+func TestReconcileSnapshot_EmptyLocal(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}}) + assert.Empty(t, s.lastMappings) +} + +// --- handleMappingStream tests for batched snapshot ID accumulation --- + +// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is +// marked done only after the final InitialSyncComplete message, even when +// the snapshot arrives in multiple batches. +func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) { + checker := health.NewChecker(nil, nil) + s := &Server{ + Logger: log.StandardLogger(), + healthChecker: checker, + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // batch 1: no sync-complete + {}, // batch 2: no sync-complete + {InitialSyncComplete: true}, // batch 3: sync done + }, + } + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.True(t, syncDone, "sync should be marked done after final batch") +} + +// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages +// arriving after InitialSyncComplete do not trigger a second reconciliation. +func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + // Simulate state left over from a previous sync. 
+ s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"} + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // post-sync empty message — must not reconcile + }, + } + + syncDone := true // sync already completed in a previous stream + err := s.handleMappingStream(context.Background(), stream, &syncDone) + require.NoError(t, err) + + assert.Len(t, s.lastMappings, 2, + "post-sync messages must not trigger reconciliation — all entries should survive") +} + +// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the +// stream closes before sync completes, no reconciliation occurs. +func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + stream := &mockMappingStream{} // no messages → immediate EOF + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.False(t, syncDone, "sync should not be marked done on immediate EOF") + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync never completed") +} + +// mockErrRecvStream returns an error on the second Recv to verify +// handleMappingStream returns without completing sync. 
+type mockErrRecvStream struct { + mockMappingStream + calls int +} + +func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) { + m.calls++ + if m.calls == 1 { + return &proto.GetMappingUpdateResponse{}, nil + } + return nil, io.ErrUnexpectedEOF +} + +func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + syncDone := false + err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone) + assert.Error(t, err) + assert.False(t, syncDone) + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error") +}