Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/windows/common/WslClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,8 @@ int WslaShell(_In_ std::wstring_view commandLine)
parser.AddArgument(Utf8String(shell), L"--shell");
parser.AddArgument(
SetFlag<int, WslaFeatureFlagsDnsTunneling>(reinterpret_cast<int&>(sessionSettings.FeatureFlags)), L"--dns-tunneling");
parser.AddArgument(
SetFlag<int, WslaFeatureFlagsPmemVhds>(reinterpret_cast<int&>(sessionSettings.FeatureFlags)), L"--pmem-vhds");
parser.AddArgument(
SetFlag<int, WslaFeatureFlagsVirtioFs>(reinterpret_cast<int&>(sessionSettings.FeatureFlags)), L"--virtiofs");
parser.AddArgument(Integer(sessionSettings.MemoryMb), L"--memory");
Expand Down
49 changes: 49 additions & 0 deletions src/windows/common/hcs_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,53 @@ inline void to_json(nlohmann::json& j, const DebugOptions& d)
OMIT_IF_EMPTY(j, d, ShutdownOrResetSavedStateFileName);
}

enum class VirtualPMemImageFormat
{
Vhdx,
Vhd1
};

NLOHMANN_JSON_SERIALIZE_ENUM(
VirtualPMemImageFormat,
{
{VirtualPMemImageFormat::Vhdx, "Vhdx"},
{VirtualPMemImageFormat::Vhd1, "Vhd1"},
})

struct VirtualPMemDevice
{
std::wstring HostPath;
bool ReadOnly;
VirtualPMemImageFormat ImageFormat;
// uint64_t SizeBytes;
// std::map<uint64_t, VirtualPMemMapping> Mappings;

NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(VirtualPMemDevice, HostPath, ReadOnly, ImageFormat);
};

enum class VirtualPMemBackingType
{
Virtual,
Physical
};

NLOHMANN_JSON_SERIALIZE_ENUM(
VirtualPMemBackingType,
{
{VirtualPMemBackingType::Virtual, "Virtual"},
{VirtualPMemBackingType::Physical, "Physical"},
})

struct VirtualPMemController
{
std::map<std::string, VirtualPMemDevice> Devices;
uint8_t MaximumCount;
uint64_t MaximumSizeBytes;
VirtualPMemBackingType Backing;

NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(VirtualPMemController, Devices, MaximumCount, MaximumSizeBytes, Backing);
};

struct Devices
{
std::optional<VirtioSerial> VirtioSerial;
Expand All @@ -466,6 +513,7 @@ struct Devices
EmptyObject Battery;
HvSocket HvSocket;
std::map<std::string, Scsi> Scsi;
std::optional<VirtualPMemController> VirtualPMem;
};

inline void to_json(nlohmann::json& j, const Devices& devices)
Expand All @@ -478,6 +526,7 @@ inline void to_json(nlohmann::json& j, const Devices& devices)
{"Scsi", devices.Scsi}};

OMIT_IF_EMPTY(j, devices, VirtioSerial);
OMIT_IF_EMPTY(j, devices, VirtualPMem);
}

struct VirtualMachine
Expand Down
140 changes: 96 additions & 44 deletions src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,17 +282,73 @@ void WSLAVirtualMachine::Start()
vmSettings.Chipset.Uefi = std::move(uefiSettings);
}

// Initialize other devices.
vmSettings.Devices.Scsi["0"] = hcs::Scsi{};
hcs::HvSocket hvSocketConfig{};
#ifdef WSL_KERNEL_MODULES_PATH

auto kernelModulesPath = std::filesystem::path(TEXT(WSL_KERNEL_MODULES_PATH));

#else

auto kernelModulesPath = basePath / L"tools" / L"modules.vhd";

#endif

// Initialize the boot VHDs.
std::variant<ULONG, std::string> rootVhd;
std::variant<ULONG, std::string> modulesVhd;
hcs::Scsi scsiController{};
if (!FeatureEnabled(WslaFeatureFlagsPmemVhds))
{
ULONG nextLun = 0;
auto attachScsiDisk = [&](PCWSTR path) {
auto lun = nextLun;
nextLun += 1;
hcs::Attachment disk{};
disk.Type = hcs::AttachmentType::VirtualDisk;
disk.Path = path;
disk.ReadOnly = true;
disk.SupportCompressedVolumes = true;
disk.AlwaysAllowSparseFiles = true;
disk.SupportEncryptedFiles = true;
scsiController.Attachments[std::to_string(lun)] = std::move(disk);
AttachedDisk attachedDisk{path};
m_attachedDisks.emplace(lun, std::move(attachedDisk));
return lun;
};

rootVhd = attachScsiDisk(m_settings.RootVhd.c_str());
modulesVhd = attachScsiDisk(kernelModulesPath.c_str());
}
else
{
hcs::VirtualPMemController pmemController;
pmemController.MaximumCount = 0;
pmemController.MaximumSizeBytes = 0;
pmemController.Backing = hcs::VirtualPMemBackingType::Virtual;
auto attachPmemDisk = [&](PCWSTR path) {
ULONG deviceId = pmemController.MaximumCount;
pmemController.MaximumCount += 1;
hcs::VirtualPMemDevice vhd;
vhd.HostPath = path;
vhd.ReadOnly = true;
vhd.ImageFormat = hcs::VirtualPMemImageFormat::Vhd1;
pmemController.Devices[std::to_string(deviceId)] = std::move(vhd);
return std::format("/dev/pmem{}", deviceId);
};

rootVhd = attachPmemDisk(m_settings.RootVhd.c_str());
modulesVhd = attachPmemDisk(kernelModulesPath.c_str());
vmSettings.Devices.VirtualPMem = std::move(pmemController);
}

// Initialize the SCSI controller.
vmSettings.Devices.Scsi["0"] = std::move(scsiController);

// Construct a security descriptor that allows system and the current user.
wil::unique_hlocal_string userSidString;
THROW_LAST_ERROR_IF(!ConvertSidToStringSidW(m_userSid, &userSidString));

std::wstring securityDescriptor{L"D:P(A;;FA;;;SY)(A;;FA;;;"};
securityDescriptor += userSidString.get();
securityDescriptor += L")";
std::wstring securityDescriptor = std::format(L"D:P(A;;FA;;;SY)(A;;FA;;;{})", userSidString.get());
hcs::HvSocket hvSocketConfig{};
hvSocketConfig.HvSocketConfig.DefaultBindSecurityDescriptor = securityDescriptor;
hvSocketConfig.HvSocketConfig.DefaultConnectSecurityDescriptor = securityDescriptor;
vmSettings.Devices.HvSocket = std::move(hvSocketConfig);
Expand Down Expand Up @@ -352,20 +408,29 @@ void WSLAVirtualMachine::Start()

ConfigureNetworking();

// Mount the kernel modules VHD.

#ifdef WSL_KERNEL_MODULES_PATH

auto kernelModulesPath = std::filesystem::path(TEXT(WSL_KERNEL_MODULES_PATH));

#else
// Configure mounts.
auto getDevicePath = [&](std::variant<ULONG, std::string>& vhd) -> const std::string& {
// If the variant holds the SCSI LUN, query the guest for the device path.
if (std::holds_alternative<ULONG>(vhd))
{
const auto lun = std::get<ULONG>(vhd);
auto it = m_attachedDisks.find(lun);
WI_ASSERT(it != m_attachedDisks.end() && it->second.Device.empty());

auto kernelModulesPath = basePath / L"tools" / L"modules.vhd";
it->second.Device = GetVhdDevicePath(lun);
vhd = it->second.Device;
}

#endif
return std::get<std::string>(vhd);
};

auto [_, device] = AttachDisk(kernelModulesPath.c_str(), true);
Mount(m_initChannel, device.c_str(), "", "ext4", "ro", WSLA_MOUNT::KernelModules);
Mount(m_initChannel, getDevicePath(rootVhd).c_str(), "/mnt", m_settings.RootVhdType.c_str(), "ro", WSLAMountFlagsChroot | WSLAMountFlagsWriteableOverlayFs);
Mount(m_initChannel, nullptr, "/dev", "devtmpfs", "", 0);
Mount(m_initChannel, nullptr, "/sys", "sysfs", "", 0);
Mount(m_initChannel, nullptr, "/proc", "proc", "", 0);
Mount(m_initChannel, nullptr, "/dev/pts", "devpts", "noatime,nosuid,noexec,gid=5,mode=620", 0);
Mount(m_initChannel, nullptr, "/sys/fs/cgroup", "cgroup2", "", 0);
Mount(m_initChannel, getDevicePath(modulesVhd).c_str(), "", "ext4", "ro", WSLA_MOUNT::KernelModules);

// Configure GPU if requested.
if (FeatureEnabled(WslaFeatureFlagsGPU))
Expand All @@ -382,24 +447,6 @@ void WSLAVirtualMachine::Start()
}

wsl::windows::common::hcs::ModifyComputeSystem(m_computeSystem.get(), wsl::shared::ToJsonW(gpuRequest).c_str());
}

ConfigureMounts();
}

void WSLAVirtualMachine::ConfigureMounts()
{
auto [_, device] = AttachDisk(m_settings.RootVhd.c_str(), true);

Mount(m_initChannel, device.c_str(), "/mnt", m_settings.RootVhdType.c_str(), "ro", WSLAMountFlagsChroot | WSLAMountFlagsWriteableOverlayFs);
Mount(m_initChannel, nullptr, "/dev", "devtmpfs", "", 0);
Mount(m_initChannel, nullptr, "/sys", "sysfs", "", 0);
Mount(m_initChannel, nullptr, "/proc", "proc", "", 0);
Mount(m_initChannel, nullptr, "/dev/pts", "devpts", "noatime,nosuid,noexec,gid=5,mode=620", 0);
Mount(m_initChannel, nullptr, "/sys/fs/cgroup", "cgroup2", "", 0);

if (FeatureEnabled(WslaFeatureFlagsGPU)) // TODO: re-think how GPU settings should work at the session level API.
{
MountGpuLibraries("/usr/lib/wsl/lib", "/usr/lib/wsl/drivers");
}
}
Expand Down Expand Up @@ -680,17 +727,10 @@ std::pair<ULONG, std::string> WSLAVirtualMachine::AttachDisk(_In_ PCWSTR Path, _

vhdAdded = true;

WSLA_GET_DISK message{};
message.Header.MessageSize = sizeof(message);
message.Header.MessageType = WSLA_GET_DISK::Type;
message.ScsiLun = Lun;
const auto& response = m_initChannel.Transaction(message);

THROW_HR_IF_MSG(E_FAIL, response.Result != 0, "Failed to attach disk, init returned: %lu", response.Result);

const auto devicePath = GetVhdDevicePath(Lun);
cleanup.release();

disk.Device = response.Buffer;
disk.Device = std::move(devicePath);
Device = disk.Device;
m_attachedDisks.emplace(Lun, std::move(disk));
});
Expand Down Expand Up @@ -801,6 +841,18 @@ WSLAVirtualMachine::ConnectedSocket WSLAVirtualMachine::ConnectSocket(wsl::share
return socket;
}

std::string WSLAVirtualMachine::GetVhdDevicePath(ULONG Lun)
{
WSLA_GET_DISK message{};
message.Header.MessageSize = sizeof(message);
message.Header.MessageType = WSLA_GET_DISK::Type;
message.ScsiLun = Lun;
const auto& response = m_initChannel.Transaction(message);
THROW_HR_IF_MSG(E_FAIL, response.Result != 0, "Failed to get disk path, init returned: %lu", response.Result);

return response.Buffer;
}

void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd)
{
static_assert(WSLAFdTypeLinuxFileInput == WslaOpenFlagsRead);
Expand Down
2 changes: 1 addition & 1 deletion src/windows/wslaservice/exe/WSLAVirtualMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput, const WSLA_PROCESS_FD** TtyControl);

void ConfigureNetworking();
void ConfigureMounts();
void OnExit(_In_ const HCS_EVENT* Event);
void OnCrash(_In_ const HCS_EVENT* Event);
bool FeatureEnabled(WSLAFeatureFlags Flag) const;
Expand All @@ -117,6 +116,7 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);

ConnectedSocket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
std::string GetVhdDevicePath(ULONG Lun);
static void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
void LaunchPortRelay();
void RemoveShare(_In_ const MountedFolderInfo& MountInfo);
Expand Down
1 change: 1 addition & 0 deletions src/windows/wslaservice/inc/wslaservice.idl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ typedef enum _WSLAFeatureFlags
WslaFeatureFlagsEarlyBootDmesg = 2,
WslaFeatureFlagsGPU = 4,
WslaFeatureFlagsVirtioFs = 8,
WslaFeatureFlagsPmemVhds = 16,
} WSLAFeatureFlags;

struct WSLA_SESSION_SETTINGS {
Expand Down
37 changes: 37 additions & 0 deletions test/windows/WSLATests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,43 @@ class WSLATests
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "lsmod | grep ^xsk_diag"}, 0);
}

TEST_METHOD(PmemVhds)
{
WSL2_TEST_ONLY();

// Test with SCSI boot VHDs.
{
auto settings = GetDefaultSessionSettings();
WI_ClearFlag(settings.FeatureFlags, WslaFeatureFlagsPmemVhds);

auto session = CreateSession(settings);

// Validate that SCSI devices are present and PMEM devices are not.
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -b /dev/sda"}, 0);
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -b /dev/sdb"}, 0);
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -b /dev/pmem0"}, 1);
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -b /dev/pmem1"}, 1);

// Verify that the SCSI device is readable.
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "dd if=/dev/sda of=/dev/null bs=512 count=1 2>&1"}, 0);
}

// Test with PMEM boot VHDs enabled.
{
auto settings = GetDefaultSessionSettings();
WI_SetFlag(settings.FeatureFlags, WslaFeatureFlagsPmemVhds);

auto session = CreateSession(settings);

// Validate that PMEM devices are present.
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -b /dev/pmem0"}, 0);
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -b /dev/pmem1"}, 0);

// Verify that the PMEM devices can be read from.
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "dd if=/dev/pmem0 of=/dev/null bs=512 count=1 2>&1"}, 0);
}
}

TEST_METHOD(CreateRootNamespaceProcess)
{
WSL2_TEST_ONLY();
Expand Down