From 54b045d9ca511b15dd6ea43e6edff717161b2ca4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 26 Apr 2024 16:37:27 +0200 Subject: [PATCH] Replaces powershell with the route command and cache route lookups on windows (#1880) --- client/internal/routemanager/client.go | 17 +++- client/internal/routemanager/routemanager.go | 7 +- client/internal/routemanager/systemops.go | 42 +++------ .../routemanager/systemops_android.go | 4 +- .../internal/routemanager/systemops_darwin.go | 12 +-- .../routemanager/systemops_darwin_test.go | 2 +- client/internal/routemanager/systemops_ios.go | 4 +- .../internal/routemanager/systemops_linux.go | 63 ++++++++----- .../routemanager/systemops_nonlinux.go | 5 +- .../internal/routemanager/systemops_test.go | 48 +++++++--- .../routemanager/systemops_windows.go | 92 +++++++++---------- 11 files changed, 161 insertions(+), 135 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index d41ed422b..3569d13ae 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -3,6 +3,7 @@ package routemanager import ( "context" "fmt" + "net" "net/netip" "time" @@ -215,7 +216,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { + if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } @@ -256,7 +257,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // otherwise add the route to the system - if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { + if err := addVPNRoute(c.network, c.getAsInterface()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -344,3 +345,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } } + +func (c *clientNetwork) getAsInterface() *net.Interface { + intf, err := net.InterfaceByName(c.wgInterface.Name()) + if err != nil { + log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err) + intf = &net.Interface{ + Name: c.wgInterface.Name(), + } + } + + return intf +} diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index 8f9ff9f4b..7715aa819 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -5,6 +5,7 @@ package routemanager import ( "errors" "fmt" + "net" "net/netip" "sync" @@ -17,7 +18,7 @@ import ( type ref struct { count int nexthop netip.Addr - intf string + intf *net.Interface } type RouteManager struct { @@ -30,8 +31,8 @@ type RouteManager struct { mutex sync.Mutex } -type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) -type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { // TODO: read initial routing table into refCountMap diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index 1ee54b746..1f37a8a3c 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -60,17 +60,13 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { return nil } - var exitIntf string gatewayHop, intf, err := getNextHop(defaultGateway) if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } - if intf != nil { - exitIntf = intf.Name - } log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) + return addToRouteTable(gatewayPrefix, gatewayHop, intf) } func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { @@ -84,7 +80,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { return netip.Addr{}, nil, ErrRouteNotFound } - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { if preferredSrc == nil { return netip.Addr{}, nil, ErrRouteNotFound @@ -153,12 +149,7 @@ func isSubRange(prefix netip.Prefix) (bool, error) { // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { +func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) { addr := prefix.Addr() switch { case addr.IsLoopback(), @@ -168,39 +159,34 @@ func addRouteToNonVPNIntf( addr.IsUnspecified(), addr.IsMulticast(): - return netip.Addr{}, "", ErrRouteNotAllowed + return netip.Addr{}, nil, ErrRouteNotAllowed } // Determine the exit interface and next hop for the prefix, so we can add a specific route nexthop, intf, err := getNextHop(addr) if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err) } log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } + exitIntf := intf vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr") } // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() { log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } + exitIntf = initialIntf } log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err) } return exitNextHop, exitIntf, nil @@ -208,7 +194,7 @@ func addRouteToNonVPNIntf( // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // in two /1 prefixes to avoid replacing the existing default route -func genericAddVPNRoute(prefix netip.Prefix, intf string) error { +func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if prefix == defaultv4 { if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { return err @@ -250,7 +236,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf string) error { } // addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { +func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error { ok, err := existsInRouteTable(prefix) if err != nil { return fmt.Errorf("exists in route table: %w", err) @@ -277,7 +263,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf string) error { // genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, // it will remove the split /1 prefixes -func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { +func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if prefix == defaultv4 { var result *multierror.Error if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { @@ -343,7 +329,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n } *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { + func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) { addr := prefix.Addr() nexthop, intf := initialNextHopV4, initialIntfV4 if addr.Is6() { diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 34d2d270f..4d23d3910 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -24,10 +24,10 @@ func enableIPForwarding() error { return nil } -func addVPNRoute(netip.Prefix, string) error { +func addVPNRoute(netip.Prefix, *net.Interface) error { return nil } -func removeVPNRoute(netip.Prefix, string) error { +func removeVPNRoute(netip.Prefix, *net.Interface) error { return nil } diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go index f7ce72a4e..017dc6c28 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops_darwin.go @@ -27,15 +27,15 @@ func cleanupRouting() error { return cleanupRoutingWithRouteManager(routeManager) } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return routeCmd("add", prefix, nexthop, intf) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return routeCmd("delete", prefix, nexthop, intf) } -func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { inet := "-inet" network := prefix.String() if prefix.IsSingleIP() { @@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin // Special case for IPv6 split default route, pointing to the wg interface fails // TODO: Remove once we have IPv6 support on the interface if prefix.Bits() == 1 { - intf = "lo0" + intf = &net.Interface{Name: "lo0"} } } args := []string{"-n", action, inet, network} if nexthop.IsValid() { args = append(args, nexthop.Unmap().String()) - } else if intf != "" { - args = append(args, "-interface", intf) + } else if intf != nil { + args = append(args, "-interface", intf.Name) } if err := retryRouteCmd(args); err != nil { diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go index cc9bb9db5..c23a7cde3 100644 --- a/client/internal/routemanager/systemops_darwin_test.go +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -33,7 +33,7 @@ func init() { func TestConcurrentRoutes(t *testing.T) { baseIP := netip.MustParseAddr("192.0.2.0") - intf := "lo0" + intf := &net.Interface{Name: "lo0"} var wg sync.WaitGroup for i := 0; i < 1024; i++ { diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 34d2d270f..4d23d3910 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -24,10 +24,10 @@ func enableIPForwarding() error { return nil } -func addVPNRoute(netip.Prefix, string) error { +func addVPNRoute(netip.Prefix, *net.Interface) error { return nil } -func removeVPNRoute(netip.Prefix, string) error { +func removeVPNRoute(netip.Prefix, *net.Interface) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 7c77c9fbb..ce0c07ce6 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -46,9 +46,6 @@ var routeManager = &RouteManager{} // originalSysctl stores the original sysctl values before they are modified var originalSysctl map[string]int -// determines whether to use the legacy routing setup -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() - // sysctlFailed is used as an indicator to emit a warning when default routes are configured var sysctlFailed bool @@ -62,6 +59,20 @@ type ruleParams struct { description string } +// isLegacy determines whether to use the legacy routing setup +func isLegacy() bool { + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() +} + +// setIsLegacy sets the legacy routing setup +func setIsLegacy(b bool) { + if b { + os.Setenv("NB_USE_LEGACY_ROUTING", "true") + } else { + os.Unsetenv("NB_USE_LEGACY_ROUTING") + } +} + func getSetupRules() []ruleParams { return []ruleParams{ {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, @@ -82,7 +93,7 @@ func getSetupRules() []ruleParams { // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { - if isLegacy { + if isLegacy() { log.Infof("Using legacy routing setup") return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) } @@ -111,7 +122,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before if err := addRule(rule); err != nil { if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") - isLegacy = true + setIsLegacy(true) return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) @@ -125,7 +136,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func cleanupRouting() error { - if isLegacy { + if isLegacy() { return cleanupRoutingWithRouteManager(routeManager) } @@ -154,16 +165,16 @@ func cleanupRouting() error { return result.ErrorOrNil() } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) } -func addVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { +func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if isLegacy() { return genericAddVPNRoute(prefix, intf) } @@ -185,8 +196,8 @@ func addVPNRoute(prefix netip.Prefix, intf string) error { return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { +func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if isLegacy() { return genericRemoveVPNRoute(prefix, intf) } @@ -244,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) { } // addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { +func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, @@ -316,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { } // removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -470,20 +481,22 @@ func removeRule(params ruleParams) error { } // addNextHop adds the gateway and device to the route. -func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { - if addr.IsValid() { - route.Gw = addr.AsSlice() - if intf == "" { - intf = addr.Zone() - } +func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error { + if intf != nil { + route.LinkIndex = intf.Index } - if intf != "" { - link, err := netlink.LinkByName(intf) - if err != nil { - return fmt.Errorf("set interface %s: %w", intf, err) + if addr.IsValid() { + route.Gw = addr.AsSlice() + + // if zone is set, it means the gateway is a link-local address, so we set the link index + if addr.Zone() != "" && intf == nil { + link, err := netlink.LinkByName(addr.Zone()) + if err != nil { + return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err) + } + route.LinkIndex = link.Attrs().Index } - route.LinkIndex = link.Attrs().Index } return nil diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 38026107e..91879790a 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -3,6 +3,7 @@ package routemanager import ( + "net" "net/netip" "runtime" @@ -14,10 +15,10 @@ func enableIPForwarding() error { return nil } -func addVPNRoute(prefix netip.Prefix, intf string) error { +func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { return genericAddVPNRoute(prefix, intf) } -func removeVPNRoute(prefix netip.Prefix, intf string) error { +func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { return genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 9f906c06f..8a92ac579 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -50,6 +50,8 @@ func TestAddRemoveRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { + t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { @@ -67,7 +69,11 @@ func TestAddRemoveRoutes(t *testing.T) { assert.NoError(t, cleanupRouting()) }) - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + index, err := net.InterfaceByName(wgInterface.Name()) + require.NoError(t, err, "InterfaceByName should not return err") + intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} + + err = addVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { @@ -78,7 +84,7 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericRemoveVPNRoute should not return err") prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) @@ -182,12 +188,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) { } for n, testCase := range testCases { + var buf bytes.Buffer log.SetOutput(&buf) defer func() { log.SetOutput(os.Stderr) }() t.Run(testCase.name, func(t *testing.T) { + t.Setenv("NB_USE_LEGACY_ROUTING", "true") + t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { @@ -200,14 +210,18 @@ func TestAddExistAndRemoveRoute(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + index, err := net.InterfaceByName(wgInterface.Name()) + require.NoError(t, err, "InterfaceByName should not return err") + intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} + // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := addVPNRoute(testCase.preExistingPrefix, intf) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + err = addVPNRoute(testCase.prefix, intf) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -217,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeVPNRoute(testCase.prefix, intf) require.NoError(t, err, "should not return err") } @@ -345,43 +359,47 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, cleanupRouting()) }) + index, err := net.InterfaceByName(wgIface.Name()) + require.NoError(t, err, "InterfaceByName should not return err") + intf := &net.Interface{Index: index.Index, Name: wgIface.Name()} + // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) } diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index ba211082f..f9e75e2ed 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -6,8 +6,12 @@ import ( "fmt" "net" "net/netip" + "os" "os/exec" + "strconv" "strings" + "sync" + "time" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" @@ -21,6 +25,10 @@ type Win32_IP4RouteTable struct { Mask string } +var prefixList []netip.Prefix +var lastUpdate time.Time +var mux = sync.Mutex{} + var routeManager *RouteManager func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { @@ -32,15 +40,23 @@ func cleanupRouting() error { } func getRoutesFromTable() ([]netip.Prefix, error) { - var routes []Win32_IP4RouteTable + mux.Lock() + defer mux.Unlock() + query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" + // If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result + if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second { + return prefixList, nil + } + + var routes []Win32_IP4RouteTable err := wmi.Query(query, &routes) if err != nil { return nil, fmt.Errorf("get routes: %w", err) } - var prefixList []netip.Prefix + prefixList = nil for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { @@ -60,54 +76,29 @@ func getRoutesFromTable() ([]netip.Prefix, error) { prefixList = append(prefixList, routePrefix) } } + + lastUpdate = time.Now() return prefixList, nil } -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error { - destinationPrefix := prefix.String() - psCmd := "New-NetRoute" - - addressFamily := "IPv4" - if prefix.Addr().Is6() { - addressFamily = "IPv6" - } - - script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`, - psCmd, addressFamily, destinationPrefix, - ) - - if intfIdx != "" { - script = fmt.Sprintf( - `%s -InterfaceIndex %s`, script, intfIdx, - ) - } else { - script = fmt.Sprintf( - `%s -InterfaceAlias "%s"`, script, intf, - ) - } +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { + args := []string{"add", prefix.String()} if nexthop.IsValid() { - script = fmt.Sprintf( - `%s -NextHop "%s"`, script, nexthop, - ) + args = append(args, nexthop.Unmap().String()) + } else { + addr := "0.0.0.0" + if prefix.Addr().Is6() { + addr = "::" + } + args = append(args, addr) } - out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell %s: %s", script, string(out)) - - if err != nil { - return fmt.Errorf("PowerShell add route: %w", err) + if intf != nil { + args = append(args, "if", strconv.Itoa(intf.Index)) } - return nil -} - -func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"add", prefix.String(), nexthop.Unmap().String()} - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) if err != nil { return fmt.Errorf("route add: %w", err) @@ -116,21 +107,20 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { return nil } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - var intfIdx string - if nexthop.Zone() != "" { - intfIdx = nexthop.Zone() +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { + if nexthop.Zone() != "" && intf == nil { + zone, err := strconv.Atoi(nexthop.Zone()) + if err != nil { + return fmt.Errorf("invalid zone: %w", err) + } + intf = &net.Interface{Index: zone} nexthop.WithZone("") } - // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" || intfIdx != "" { - return addRoutePowershell(prefix, nexthop, intf, intfIdx) - } return addRouteCmd(prefix, nexthop, intf) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error { args := []string{"delete", prefix.String()} if nexthop.IsValid() { nexthop.WithZone("") @@ -145,3 +135,7 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) err } return nil } + +func isCacheDisabled() bool { + return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" +}