diff --git a/src/linux/init/GnsEngine.cpp b/src/linux/init/GnsEngine.cpp index 37ee355a8..7f4189541 100644 --- a/src/linux/init/GnsEngine.cpp +++ b/src/linux/init/GnsEngine.cpp @@ -268,6 +268,13 @@ void GnsEngine::ProcessRouteChange(Interface& interface, const wsl::shared::hns: auto interfaceRoute = Route{addrFamily, {{addrFamily, route.SitePrefixLength, nextHopValue}}, interface.Index(), defaultRoute, to, route.Metric}; + // Extract preferred source if provided + if (!route.PreferredSource.empty()) + { + const auto preferredSourceValue = wsl::shared::string::WideToMultiByte(route.PreferredSource); + interfaceRoute.preferredSource = Address{addrFamily, 0, preferredSourceValue}; + } + auto routeString = utils::Stringify(interfaceRoute); if (action == ModifyRequestType::Add) diff --git a/src/linux/netlinkutil/Route.h b/src/linux/netlinkutil/Route.h index 7a098c32f..fb41e24d7 100644 --- a/src/linux/netlinkutil/Route.h +++ b/src/linux/netlinkutil/Route.h @@ -13,6 +13,7 @@ struct Route std::optional
to; int metric = 0; bool isLoopbackRoute = false; + std::optional
preferredSource; Route(int family, const std::optional
& via, int dev, bool defaultRoute, const std::optional
& to, int metric); diff --git a/src/linux/netlinkutil/RoutingTable.cpp b/src/linux/netlinkutil/RoutingTable.cpp index 9ec761c0c..242af6df9 100644 --- a/src/linux/netlinkutil/RoutingTable.cpp +++ b/src/linux/netlinkutil/RoutingTable.cpp @@ -268,22 +268,29 @@ void RoutingTable::ModifyLinkLocalRouteImpl(const Route& route, int operation, i { utils::AddressAttribute to; utils::IntegerAttribute metric; + utils::AddressAttribute preferredSource; } __attribute__((packed)); GNS_LOG_INFO( - "SendMessage Route (to {}, via {}), operation ({}), netLinkflags ({})", + "SendMessage Route (to {}, via {}, preferredSource {}), operation ({}), netLinkflags ({})", route.to.has_value() ? route.to.value().Addr().c_str() : "[empty]", route.via.has_value() ? route.via.value().Addr().c_str() : "[empty]", + route.preferredSource.has_value() ? route.preferredSource.value().Addr().c_str() : "[empty]", RouteOperationToString(operation), NetLinkFormatFlagsToString(flags).c_str()); SendMessage(route, operation, flags, [&](Message& message) { GNS_LOG_INFO( - "InitializeAddressAttribute RTA_DST ({}) RTA_GATEWAY ([not set]), RTA_PRIORITY ({})", + "InitializeAddressAttribute RTA_DST ({}) RTA_GATEWAY ([not set]), RTA_PRIORITY ({}), RTA_PREFSRC ({})", route.to.has_value() ? route.to.value().Addr().c_str() : "[empty]", - route.metric); + route.metric, + route.preferredSource.has_value() ? route.preferredSource.value().Addr().c_str() : "[not set]"); utils::InitializeAddressAttribute(message.to, route.to.value(), RTA_DST); utils::InitializeIntegerAttribute(message.metric, route.metric, RTA_PRIORITY); + if (route.preferredSource.has_value()) + { + utils::InitializeAddressAttribute(message.preferredSource, route.preferredSource.value(), RTA_PREFSRC); + } }); } @@ -300,24 +307,31 @@ void RoutingTable::ModifyOfflinkRouteImpl(const Route& route, int operation, int utils::AddressAttribute to; utils::AddressAttribute via; utils::IntegerAttribute metric; + utils::AddressAttribute preferredSource; } __attribute__((packed)); GNS_LOG_INFO( - "SendMessage Route (to {}, via {}), operation ({}), netLinkflags ({})", + "SendMessage Route (to {}, via {}, preferredSource {}), operation ({}), netLinkflags ({})", route.to.has_value() ? route.to.value().Addr().c_str() : "[empty]", route.via.has_value() ? route.via.value().Addr().c_str() : "[empty]", + route.preferredSource.has_value() ? route.preferredSource.value().Addr().c_str() : "[empty]", RouteOperationToString(operation), NetLinkFormatFlagsToString(flags).c_str()); SendMessage(route, operation, flags, [&](Message& message) { GNS_LOG_INFO( - "InitializeAddressAttribute RTA_DST ({}) RTA_GATEWAY ({}), RTA_PRIORITY ({})", + "InitializeAddressAttribute RTA_DST ({}) RTA_GATEWAY ({}), RTA_PRIORITY ({}), RTA_PREFSRC ({})", route.to.has_value() ? route.to.value().Addr().c_str() : "[empty]", route.via.has_value() ? route.via.value().Addr().c_str() : "[empty]", - route.metric); + route.metric, + route.preferredSource.has_value() ? route.preferredSource.value().Addr().c_str() : "[not set]"); utils::InitializeAddressAttribute(message.to, route.to.value(), RTA_DST); utils::InitializeAddressAttribute(message.via, route.via.value(), RTA_GATEWAY); utils::InitializeIntegerAttribute(message.metric, route.metric, RTA_PRIORITY); + if (route.preferredSource.has_value()) + { + utils::InitializeAddressAttribute(message.preferredSource, route.preferredSource.value(), RTA_PREFSRC); + } }); } diff --git a/src/shared/inc/hns_schema.h b/src/shared/inc/hns_schema.h index 2d9f98aa4..572ac5a07 100644 --- a/src/shared/inc/hns_schema.h +++ b/src/shared/inc/hns_schema.h @@ -160,8 +160,9 @@ struct Route uint8_t SitePrefixLength{}; uint32_t Metric{}; uint16_t Family{}; + std::wstring PreferredSource; - NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Route, NextHop, DestinationPrefix, SitePrefixLength, Metric, Family); + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Route, NextHop, DestinationPrefix, SitePrefixLength, Metric, Family, PreferredSource); }; enum class ModifyRequestType diff --git a/src/windows/service/exe/WslCoreNetworkEndpointSettings.h b/src/windows/service/exe/WslCoreNetworkEndpointSettings.h index 5e165c9ed..a682427c7 100644 --- a/src/windows/service/exe/WslCoreNetworkEndpointSettings.h +++ b/src/windows/service/exe/WslCoreNetworkEndpointSettings.h @@ -163,6 +163,8 @@ struct EndpointRoute unsigned char SitePrefixLength = 0; unsigned int Metric = 0; bool IsAutoGeneratedPrefixRoute = false; + SOCKADDR_INET PreferredSource{}; + std::wstring PreferredSourceString{}; EndpointRoute() = default; ~EndpointRoute() noexcept = default; @@ -219,7 +221,8 @@ struct EndpointRoute { return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength && DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop && - SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric; + SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric && + PreferredSource == rhs.PreferredSource; } bool operator!=(const EndpointRoute& other) const diff --git a/src/windows/service/exe/WslCoreTcpIpStateTracking.h b/src/windows/service/exe/WslCoreTcpIpStateTracking.h index 06ad59e50..60fde178e 100644 --- a/src/windows/service/exe/WslCoreTcpIpStateTracking.h +++ b/src/windows/service/exe/WslCoreTcpIpStateTracking.h @@ -146,6 +146,7 @@ struct TrackedRoute route.SitePrefixLength = Route.SitePrefixLength; route.NextHop = Route.NextHopString; route.Metric = Route.Metric; + route.PreferredSource = Route.PreferredSourceString; return route; } diff --git a/src/windows/service/exe/WslMirroredNetworking.cpp b/src/windows/service/exe/WslMirroredNetworking.cpp index f7bd0e174..6764b60dd 100644 --- a/src/windows/service/exe/WslMirroredNetworking.cpp +++ b/src/windows/service/exe/WslMirroredNetworking.cpp @@ -208,6 +208,94 @@ void wsl::core::networking::WslMirroredNetworkManager::ProcessIpAddressChange() } } +// Helper function to find the source address on the same subnet as the next-hop. +// Uses longest prefix match to select the most specific matching subnet when multiple addresses match. +static std::optional FindSourceAddressForNextHop( + const SOCKADDR_INET& nextHop, + const std::set& addresses) +{ + std::optional bestMatch; + unsigned char longestPrefix = 0; + + for (const auto& addr : addresses) + { + // Skip if address family doesn't match + if (addr.Address.si_family != nextHop.si_family) + { + continue; + } + + // Skip if this prefix is shorter than our best match + if (addr.PrefixLength < longestPrefix) + { + continue; + } + + bool matches = false; + + if (nextHop.si_family == AF_INET) + { + // Handle IPv4 prefix matching with mask generation + uint32_t mask; + if (addr.PrefixLength == 0) + { + // Match any address (0.0.0.0/0) + mask = 0; + } + else if (addr.PrefixLength >= 32) + { + // Exact match required + mask = 0xFFFFFFFF; + } + else + { + // Generate mask for prefix length (e.g., /24 -> 0xFFFFFF00) + mask = 0xFFFFFFFF << (32 - addr.PrefixLength); + } + + // Compare network portions (convert to host byte order for masking) + uint32_t addrNetwork = ntohl(addr.Address.Ipv4.sin_addr.S_un.S_addr) & mask; + uint32_t nextHopNetwork = ntohl(nextHop.Ipv4.sin_addr.S_un.S_addr) & mask; + + matches = (addrNetwork == nextHopNetwork); + } + else if (nextHop.si_family == AF_INET6) + { + // Optimized IPv6 prefix matching using byte-by-byte comparison with memcmp and partial byte masking + matches = true; + int remainingBits = addr.PrefixLength; + const uint8_t* addrBytes = addr.Address.Ipv6.sin6_addr.u.Byte; + const uint8_t* nextHopBytes = nextHop.Ipv6.sin6_addr.u.Byte; + + // Process full bytes first (faster than bit-by-bit) + int fullBytes = remainingBits / 8; + if (fullBytes > 0 && memcmp(addrBytes, nextHopBytes, fullBytes) != 0) + { + matches = false; + } + else if (remainingBits % 8 != 0) + { + // Handle partial byte at the end + int partialBits = remainingBits % 8; + uint8_t mask = 0xFF << (8 - partialBits); + if ((addrBytes[fullBytes] & mask) != (nextHopBytes[fullBytes] & mask)) + { + matches = false; + } + } + } + + if (matches) + { + // Found a better match (longer prefix) + bestMatch = addr.Address; + longestPrefix = addr.PrefixLength; + } + } + + return bestMatch; +} + _Requires_lock_held_(m_networkLock) void wsl::core::networking::WslMirroredNetworkManager::ProcessRouteChange() { @@ -375,6 +463,15 @@ void wsl::core::networking::WslMirroredNetworkManager::ProcessRouteChange() newRoute.NextHop.si_family = route.NextHop.si_family; newRoute.NextHopString = windows::common::string::SockAddrInetToWstring(newRoute.NextHop); + // Find and set the preferred source address from the same subnet as the next-hop + // This ensures correct source address selection for multi-IP interfaces + auto sourceAddr = FindSourceAddressForNextHop(route.NextHop, endpoint.Network->IpAddresses); + if (sourceAddr.has_value()) + { + newRoute.PreferredSource = sourceAddr.value(); + newRoute.PreferredSourceString = windows::common::string::SockAddrInetToWstring(newRoute.PreferredSource); + } + // force a copy so the route strings are re-calculated in the new EndpointRoute object newRoutes.emplace_back(std::move(newRoute)); }