diff --git a/src/windows/common/Distribution.cpp b/src/windows/common/Distribution.cpp index 1ca3ee759..88c92649c 100644 --- a/src/windows/common/Distribution.cpp +++ b/src/windows/common/Distribution.cpp @@ -210,34 +210,157 @@ std::optional LookupDistributionInManifest(const DistributionList return *it; } +// Helper function to merge distributions from multiple manifests +void MergeDistributionLists(DistributionList& target, const DistributionList& source) +{ + // Merge legacy distributions + if (source.Distributions.has_value()) + { + if (!target.Distributions.has_value()) + { + target.Distributions = std::vector{}; + } + + for (const auto& dist : *source.Distributions) + { + // Check if distribution already exists (avoid duplicates) + auto it = std::find_if(target.Distributions->begin(), target.Distributions->end(), + [&](const Distribution& d) { + return d.Name == dist.Name + && d.Version == dist.Version + && d.Architecture == dist.Architecture; + }); + + if (it == target.Distributions->end()) + { + target.Distributions->push_back(dist); + } + } + } + + // Merge modern distributions + if (source.ModernDistributions.has_value()) + { + if (!target.ModernDistributions.has_value()) + { + target.ModernDistributions = std::map>{}; + } + + for (const auto& [distroName, versions] : *source.ModernDistributions) + { + auto& targetVersions = (*target.ModernDistributions)[distroName]; + + for (const auto& version : versions) + { + // Check if version already exists + auto it = std::find_if(targetVersions.begin(), targetVersions.end(), + [&](const ModernDistributionVersion& v) { + return v.Name == version.Name + && v.Version == version.Version + && v.Architecture == version.Architecture; + }); + + if (it == targetVersions.end()) + { + targetVersions.push_back(version); + } + } + } + } + + // Update default if source has one and target doesn't + if (source.Default.has_value() && !target.Default.has_value()) + { + target.Default = source.Default; + } +} + } // namespace AvailableDistributions wsl::windows::common::distribution::GetAvailable() { AvailableDistributions distributions{}; + // Determine the base manifest URL + // Priority: HKCU > HKLM > Default std::wstring url = c_defaultDistroListUrl; - std::optional appendUrl; + std::vector appendUrls; + try { - const auto registryKey = registry::OpenLxssMachineKey(); - url = registry::ReadString(registryKey.get(), nullptr, c_distroUrlRegistryValue, c_defaultDistroListUrl); + // First check HKEY_LOCAL_MACHINE + const auto machineKey = registry::OpenLxssMachineKey(); + url = registry::ReadString(machineKey.get(), nullptr, c_distroUrlRegistryValue, c_defaultDistroListUrl); + + // Read HKLM append URLs (supports REG_MULTI_SZ) + auto hklmAppendUrls = registry::ReadWideStringSet(machineKey.get(), nullptr, c_distroUrlAppendRegistryValue, {}); + appendUrls.insert(appendUrls.end(), hklmAppendUrls.begin(), hklmAppendUrls.end()); + if (url != c_defaultDistroListUrl) { - WSL_LOG("Found custom URL for distribution list", TraceLoggingValue(url.c_str(), "url")); + WSL_LOG("Found custom URL for distribution list in HKLM", TraceLoggingValue(url.c_str(), "url")); } + + if (!appendUrls.empty()) + { + WSL_LOG("Found append URLs in HKLM", TraceLoggingValue(static_cast(appendUrls.size()), "count")); + } + } + CATCH_LOG() - appendUrl = registry::ReadOptionalString(registryKey.get(), nullptr, c_distroUrlAppendRegistryValue); + try + { + // Then check HKEY_CURRENT_USER (takes precedence) + const auto userKey = registry::OpenLxssUserKey(); + + // Check if user has overridden the base URL + auto userUrl = registry::ReadOptionalString(userKey.get(), nullptr, c_distroUrlRegistryValue); + if (userUrl.has_value()) + { + url = userUrl.value(); + WSL_LOG("Found custom URL for distribution list in HKCU (overriding)", TraceLoggingValue(url.c_str(), "url")); + } + + // Read HKCU append URLs (supports REG_MULTI_SZ) - these are added to HKLM append URLs + auto hkcuAppendUrls = registry::ReadWideStringSet(userKey.get(), nullptr, c_distroUrlAppendRegistryValue, {}); + appendUrls.insert(appendUrls.end(), hkcuAppendUrls.begin(), hkcuAppendUrls.end()); + + if (!hkcuAppendUrls.empty()) + { + WSL_LOG("Found append URLs in HKCU", TraceLoggingValue(static_cast(hkcuAppendUrls.size()), "count")); + } } CATCH_LOG() + // Load the base manifest distributions.Manifest = ReadFromManifest(url); - if (appendUrl.has_value()) + // Load and merge all append manifests + if (!appendUrls.empty()) { - WSL_LOG("Found append URL for distribution list", TraceLoggingValue(appendUrl->c_str(), "url")); - - distributions.OverrideManifest = ReadFromManifest(appendUrl.value()); + for (const auto& appendUrl : appendUrls) + { + try + { + WSL_LOG("Loading append manifest", TraceLoggingValue(appendUrl.c_str(), "url")); + auto appendManifest = ReadFromManifest(appendUrl); + + // Merge into override manifest if it exists, otherwise create it + if (!distributions.OverrideManifest.has_value()) + { + distributions.OverrideManifest = appendManifest; + } + else + { + MergeDistributionLists(*distributions.OverrideManifest, appendManifest); + } + } + catch (...) + { + // Log the error but continue with other sources + LOG_CAUGHT_EXCEPTION_MSG("Failed to load append manifest from %ls", appendUrl.c_str()); + } + } } return distributions; diff --git a/src/windows/common/Distribution.h b/src/windows/common/Distribution.h index 4bbdfa3e6..eb57cf84b 100644 --- a/src/windows/common/Distribution.h +++ b/src/windows/common/Distribution.h @@ -22,12 +22,24 @@ Module Name: namespace wsl::windows::common::distribution { +// Represents a file to be injected into the distribution during installation +struct InjectedFile +{ + std::wstring Source; // "url" or "inline" + std::optional Url; // URL to download from (if Source == "url") + std::optional Sha256; // SHA256 hash for verification (if Source == "url") + std::optional Contents; // Inline file contents (if Source == "inline") + + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(InjectedFile, Source, Url, Sha256, Contents); +}; + struct DistributionArchive { std::wstring Url; std::wstring Sha256; + std::optional> Files; // Map of file paths to inject - NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(DistributionArchive, Url, Sha256); + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(DistributionArchive, Url, Sha256, Files); }; struct ModernDistributionVersion diff --git a/src/windows/common/WslInstall.cpp b/src/windows/common/WslInstall.cpp index 1aab74850..93653e396 100644 --- a/src/windows/common/WslInstall.cpp +++ b/src/windows/common/WslInstall.cpp @@ -324,5 +324,108 @@ std::pair WslInstall::InstallModernDistribution( fixedVhd ? LXSS_IMPORT_DISTRO_FLAGS_FIXED_VHD : 0, vhdSize); + // Inject files if specified in the distribution metadata + if (downloadInfo->Files.has_value() && !downloadInfo->Files->empty()) + { + try + { + PrintMessage(L"Injecting configuration files...", stdout); + + for (const auto& [targetPath, fileSpec] : *downloadInfo->Files) + { + const auto targetPathWide = wsl::shared::string::MultiByteToWide(targetPath); + + if (wsl::windows::common::string::IsEqual(fileSpec.Source, L"inline", true)) + { + // Inline content - write directly using base64 encoding to avoid shell escaping issues + if (!fileSpec.Contents.has_value()) + { + LOG_HR_MSG(E_INVALIDARG, "Inline file source specified but no contents provided for %s", targetPath.c_str()); + continue; + } + + // Convert content to base64 to safely pass through shell + const auto contentBase64 = wsl::shared::string::Base64EncodeFromWide(fileSpec.Contents->c_str()); + + // Create parent directory and decode base64 content into file + const auto command = std::format( + L"/bin/sh -c \"mkdir -p $(dirname '{}') && echo '{}' | base64 -d > '{}'\"", + targetPathWide, + wsl::shared::string::MultiByteToWide(contentBase64), + targetPathWide); + + LPCWSTR argv[] = {L"/bin/sh", L"-c", command.c_str()}; + const auto exitCode = service.LaunchProcess(&id, L"/bin/sh", 3, argv, LXSS_LAUNCH_FLAGS_NONE, nullptr, nullptr, 30000); + + if (exitCode != 0) + { + LOG_HR_MSG(E_FAIL, "Failed to inject inline file %s, exit code: %d", targetPath.c_str(), exitCode); + } + } + else if (wsl::windows::common::string::IsEqual(fileSpec.Source, L"url", true)) + { + // URL-based file - download and inject + if (!fileSpec.Url.has_value() || !fileSpec.Sha256.has_value()) + { + LOG_HR_MSG(E_INVALIDARG, "URL file source specified but no URL or SHA256 provided for %s", targetPath.c_str()); + continue; + } + + // Download file to temp location with UUID to prevent collisions + GUID uniqueId{}; + THROW_IF_FAILED(CoCreateGuid(&uniqueId)); + const auto tempFileName = std::format(L"injected_file_{}.tmp", + wsl::shared::string::GuidToString(uniqueId, wsl::shared::string::GuidToStringFlags::None)); + const auto tempFilePath = DownloadFile(*fileSpec.Url, tempFileName); + + // Verify hash + wil::unique_handle tempFile{CreateFile(tempFilePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr)}; + if (tempFile) + { + try + { + EnforceFileHash(tempFile.get(), *fileSpec.Sha256); + tempFile.reset(); + + // Read file content and inject (using base64 for safety) + const auto fileContent = wsl::shared::string::ReadFile(tempFilePath.c_str()); + const auto contentBase64 = wsl::shared::string::Base64Encode(fileContent); + + const auto command = std::format( + L"/bin/sh -c \"mkdir -p $(dirname '{}') && echo '{}' | base64 -d > '{}'\"", + targetPathWide, + wsl::shared::string::MultiByteToWide(contentBase64), + targetPathWide); + + LPCWSTR argv[] = {L"/bin/sh", L"-c", command.c_str()}; + const auto exitCode = service.LaunchProcess(&id, L"/bin/sh", 3, argv, LXSS_LAUNCH_FLAGS_NONE, nullptr, nullptr, 30000); + + if (exitCode != 0) + { + LOG_HR_MSG(E_FAIL, "Failed to inject URL-based file %s, exit code: %d", targetPath.c_str(), exitCode); + } + } + catch (...) + { + LOG_CAUGHT_EXCEPTION_MSG("Failed to inject file from URL for %s", targetPath.c_str()); + } + + // Clean up temp file + DeleteFileW(tempFilePath.c_str()); + } + } + else + { + LOG_HR_MSG(E_INVALIDARG, "Unknown file source type: %ls for %s", fileSpec.Source.c_str(), targetPath.c_str()); + } + } + } + catch (...) + { + // Log but don't fail installation if file injection fails + LOG_CAUGHT_EXCEPTION_MSG("File injection failed, but installation will continue"); + } + } + return {installedName.get(), id}; } \ No newline at end of file diff --git a/src/windows/common/registry.cpp b/src/windows/common/registry.cpp index 92fa1cf10..d30463405 100644 --- a/src/windows/common/registry.cpp +++ b/src/windows/common/registry.cpp @@ -426,6 +426,50 @@ std::vector wsl::windows::common::registry::ReadStringSet( return Values; } +std::vector wsl::windows::common::registry::ReadWideStringSet( + _In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName, const std::vector& Default) +{ + // + // Detect if the key exists and determine how large of a buffer is needed. + // If the key does not exist, return the default value. + // + + LONG Result; + DWORD Size = 0; + Result = RegGetValueW(Key, KeyName, ValueName, RRF_RT_REG_MULTI_SZ, nullptr, nullptr, &Size); + if ((Result == ERROR_PATH_NOT_FOUND) || (Result == ERROR_FILE_NOT_FOUND) || (Size == 0)) + { + return Default; + } + + ReportErrorIfFailed(Result, Key, KeyName, ValueName); + + // + // Allocate a buffer to hold the value and two NULL terminators. + // + + std::vector Buffer((Size / sizeof(WCHAR)) + 2); + + // + // Read the value. + // + + Result = RegGetValueW(Key, KeyName, ValueName, RRF_RT_REG_MULTI_SZ, nullptr, Buffer.data(), &Size); + ReportErrorIfFailed(Result, Key, KeyName, ValueName); + + // + // Convert the reg value into a vector of wide strings. + // + + std::vector Values{}; + for (auto Current = Buffer.data(); UNICODE_NULL != *Current; Current += wcslen(Current) + 1) + { + Values.push_back(Current); + } + + return Values; +} + void wsl::windows::common::registry::WriteDword(_In_ HKEY Key, _In_ LPCWSTR SubKey, _In_ LPCWSTR ValueName, _In_ DWORD Value) { const auto Result = RegSetKeyValueW(Key, SubKey, ValueName, REG_DWORD, &Value, sizeof(Value)); diff --git a/src/windows/common/registry.hpp b/src/windows/common/registry.hpp index 01a1aca61..d18f6f1b6 100644 --- a/src/windows/common/registry.hpp +++ b/src/windows/common/registry.hpp @@ -63,6 +63,7 @@ std::wstring ReadString(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWST std::optional ReadOptionalString(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName); std::vector ReadStringSet(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName, _In_ const std::vector& Default); +std::vector ReadWideStringSet(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName, _In_ const std::vector& Default); void WriteDword(_In_ HKEY Key, _In_ LPCWSTR SubKey, _In_ LPCWSTR KeyName, _In_ DWORD Value);