diff --git a/src/shared/inc/defs.h b/src/shared/inc/defs.h index cdeda12bd..b46853339 100644 --- a/src/shared/inc/defs.h +++ b/src/shared/inc/defs.h @@ -30,6 +30,10 @@ Module Name: Type(Type&&) = delete; \ Type& operator=(Type&&) = delete; +#define DEFAULT_MOVABLE(Type) \ + Type(Type&&) = default; \ + Type& operator=(Type&&) = default; + namespace wsl::shared { inline constexpr std::uint32_t VersionMajor = WSL_PACKAGE_VERSION_MAJOR; diff --git a/src/windows/common/CMakeLists.txt b/src/windows/common/CMakeLists.txt index 4424603e9..ba15ee5d1 100644 --- a/src/windows/common/CMakeLists.txt +++ b/src/windows/common/CMakeLists.txt @@ -34,6 +34,7 @@ set(SOURCES SubProcess.cpp svccomm.cpp svccommio.cpp + WSLAContainerLauncher.cpp WSLAProcessLauncher.cpp WslClient.cpp WslCoreConfig.cpp @@ -106,6 +107,7 @@ set(HEADERS SubProcess.h svccomm.hpp svccommio.hpp + WSLAContainerLauncher.h WSLAProcessLauncher.h WslClient.h WslCoreConfig.h diff --git a/src/windows/common/WSLAContainerLauncher.cpp b/src/windows/common/WSLAContainerLauncher.cpp new file mode 100644 index 000000000..6f8ce2753 --- /dev/null +++ b/src/windows/common/WSLAContainerLauncher.cpp @@ -0,0 +1,61 @@ +#include "WSLAContainerLauncher.h" + +using wsl::windows::common::ClientRunningWSLAProcess; +using wsl::windows::common::RunningWSLAContainer; +using wsl::windows::common::WSLAContainerLauncher; + +RunningWSLAContainer::RunningWSLAContainer(wil::com_ptr&& Container, std::vector&& fds) : + m_container(std::move(Container)), m_fds(std::move(fds)) +{ +} + +IWSLAContainer& RunningWSLAContainer::Get() +{ + return *m_container; +} + +WSLA_CONTAINER_STATE RunningWSLAContainer::State() +{ + WSLA_CONTAINER_STATE state{}; + THROW_IF_FAILED(m_container->GetState(&state)); + return state; +} + +ClientRunningWSLAProcess RunningWSLAContainer::GetInitProcess() +{ + wil::com_ptr process; + THROW_IF_FAILED(m_container->GetInitProcess(&process)); + + return ClientRunningWSLAProcess{std::move(process), std::move(m_fds)}; +} + +WSLAContainerLauncher::WSLAContainerLauncher( + const std::string& Image, + const std::string& Name, + const std::string& EntryPoint, + const std::vector& Arguments, + const std::vector& Environment, + ProcessFlags Flags) : + WSLAProcessLauncher(EntryPoint, Arguments, Environment, Flags), m_image(Image), m_name(Name) +{ +} + +RunningWSLAContainer WSLAContainerLauncher::Launch(IWSLASession& Session) +{ + WSLA_CONTAINER_OPTIONS options{}; + options.Image = m_image.c_str(); + options.Name = m_name.c_str(); + auto [processOptions, commandLinePtrs, environmentPtrs] = CreateProcessOptions(); + options.InitProcessOptions = processOptions; + + if (m_executable.empty()) + { + options.InitProcessOptions.Executable = nullptr; + } + + // TODO: Support volumes, ports, flags, shm size, etc. + wil::com_ptr container; + THROW_IF_FAILED(Session.CreateContainer(&options, &container)); + + return RunningWSLAContainer{std::move(container), std::move(m_fds)}; +} \ No newline at end of file diff --git a/src/windows/common/WSLAContainerLauncher.h b/src/windows/common/WSLAContainerLauncher.h new file mode 100644 index 000000000..11301afbe --- /dev/null +++ b/src/windows/common/WSLAContainerLauncher.h @@ -0,0 +1,44 @@ +#include "WSLAprocessLauncher.h" + +namespace wsl::windows::common { + +class RunningWSLAContainer +{ +public: + NON_COPYABLE(RunningWSLAContainer); + DEFAULT_MOVABLE(RunningWSLAContainer); + RunningWSLAContainer(wil::com_ptr&& Container, std::vector&& fds); + IWSLAContainer& Get(); + + WSLA_CONTAINER_STATE State(); + ClientRunningWSLAProcess GetInitProcess(); + +private: + wil::com_ptr m_container; + std::vector m_fds; +}; + +class WSLAContainerLauncher : public WSLAProcessLauncher +{ +public: + NON_COPYABLE(WSLAContainerLauncher); + NON_MOVABLE(WSLAContainerLauncher); + + WSLAContainerLauncher( + const std::string& Image, + const std::string& Name, + const std::string& EntryPoint = "", + const std::vector& Arguments = {}, + const std::vector& Environment = {}, + ProcessFlags Flags = ProcessFlags::Stdout | ProcessFlags::Stderr); + + void AddVolume(const std::string& HostPath, const std::string& ContainerPath, bool ReadOnly); + void AddPort(uint16_t WindowsPort, uint16_t ContainerPort, int Family); + + RunningWSLAContainer Launch(IWSLASession& Session); + +private: + std::string m_image; + std::string m_name; +}; +} // namespace wsl::windows::common \ No newline at end of file diff --git a/src/windows/common/WSLAProcessLauncher.cpp b/src/windows/common/WSLAProcessLauncher.cpp index 6e2b4cd2d..18ec9aa3b 100644 --- a/src/windows/common/WSLAProcessLauncher.cpp +++ b/src/windows/common/WSLAProcessLauncher.cpp @@ -150,7 +150,7 @@ ClientRunningWSLAProcess WSLAProcessLauncher::Launch(IWSLASession& Session) THROW_HR_MSG(hresult, "Failed to launch process: %hs (commandline: %hs). Errno = %i", m_executable.c_str(), commandLine.c_str(), error); } - return process.value(); + return std::move(process.value()); } std::tuple> WSLAProcessLauncher::LaunchNoThrow(IWSLASession& Session) diff --git a/src/windows/common/WSLAProcessLauncher.h b/src/windows/common/WSLAProcessLauncher.h index 97fd2ac7e..1bdcade7f 100644 --- a/src/windows/common/WSLAProcessLauncher.h +++ b/src/windows/common/WSLAProcessLauncher.h @@ -43,6 +43,9 @@ class RunningWSLAProcess }; RunningWSLAProcess(std::vector&& fds); + NON_COPYABLE(RunningWSLAProcess); + DEFAULT_MOVABLE(RunningWSLAProcess); + ProcessResult WaitAndCaptureOutput(DWORD TimeoutMs = INFINITE, std::vector>&& ExtraHandles = {}); virtual wil::unique_handle GetStdHandle(int Index) = 0; virtual wil::unique_event GetExitEvent() = 0; @@ -57,6 +60,9 @@ class RunningWSLAProcess class ClientRunningWSLAProcess : public RunningWSLAProcess { public: + NON_COPYABLE(ClientRunningWSLAProcess); + DEFAULT_MOVABLE(ClientRunningWSLAProcess); + ClientRunningWSLAProcess(wil::com_ptr&& process, std::vector&& fds); wil::unique_handle GetStdHandle(int Index) override; wil::unique_event GetExitEvent() override; @@ -68,7 +74,6 @@ class ClientRunningWSLAProcess : public RunningWSLAProcess private: wil::com_ptr m_process; }; - class WSLAProcessLauncher { public: diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 6b40f927f..dedeb2de5 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1549,7 +1549,9 @@ int WslaShell(_In_ std::wstring_view commandLine) settings.BootTimeoutMs = 30000; settings.NetworkingMode = WSLANetworkingModeNAT; std::wstring containerRootVhd; + std::string containerImage; bool help = false; + std::wstring debugShell; ArgumentParser parser(std::wstring{commandLine}, WSL_BINARY_NAME); parser.AddArgument(vhd, L"--vhd"); @@ -1559,6 +1561,8 @@ int WslaShell(_In_ std::wstring_view commandLine) parser.AddArgument(Integer(settings.CpuCount), L"--cpu"); parser.AddArgument(Utf8String(fsType), L"--fstype"); parser.AddArgument(containerRootVhd, L"--container-vhd"); + parser.AddArgument(Utf8String(containerImage), L"--image"); + parser.AddArgument(debugShell, L"--debug-shell"); parser.AddArgument(help, L"--help"); parser.Parse(); @@ -1594,24 +1598,59 @@ int WslaShell(_In_ std::wstring_view commandLine) wil::com_ptr session; settings.RootVhd = vhd.c_str(); settings.RootVhdType = fsType.c_str(); - THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, &settings, &session)); - THROW_IF_FAILED(session->GetVirtualMachine(&virtualMachine)); - wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + if (!debugShell.empty()) + { + THROW_IF_FAILED(userSession->OpenSessionByName(debugShell.c_str(), &session)); + } + else + { + THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, &settings, &session)); + THROW_IF_FAILED(session->GetVirtualMachine(&virtualMachine)); - if (!containerRootVhd.empty()) + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + if (!containerRootVhd.empty()) + { + wsl::windows::common::WSLAProcessLauncher initProcessLauncher{shell, {shell, "/etc/lsw-init.sh"}}; + auto initProcess = initProcessLauncher.Launch(*session); + THROW_HR_IF(E_FAIL, initProcess.WaitAndCaptureOutput().Code != 0); + } + } + + std::optional> container; + std::optional process; + + if (containerImage.empty()) { - wsl::windows::common::WSLAProcessLauncher initProcessLauncher{shell, {shell, "/etc/lsw-init.sh"}}; - auto initProcess = initProcessLauncher.Launch(*session); - THROW_HR_IF(E_FAIL, initProcess.WaitAndCaptureOutput().Code != 0); + wsl::windows::common::WSLAProcessLauncher launcher{shell, {shell}, {"TERM=xterm-256color"}, ProcessFlags::None}; + launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}); + launcher.AddFd(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}); + launcher.AddFd(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl}); + + process = launcher.Launch(*session); } + else + { + std::vector fds{ + WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}, + WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}, + WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl}, + }; - wsl::windows::common::WSLAProcessLauncher launcher{shell, {shell}, {"TERM=xterm-256color"}, ProcessFlags::None}; - launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}); - launcher.AddFd(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}); - launcher.AddFd(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl}); + WSLA_CONTAINER_OPTIONS containerOptions{}; + containerOptions.Image = containerImage.c_str(); + containerOptions.Name = "test-container"; + containerOptions.InitProcessOptions.Fds = fds.data(); + containerOptions.InitProcessOptions.FdsCount = static_cast(fds.size()); - auto process = launcher.Launch(*session); + container.emplace(); + THROW_IF_FAILED(session->CreateContainer(&containerOptions, &container.value())); + + wil::com_ptr initProcess; + THROW_IF_FAILED((*container)->GetInitProcess(&initProcess)); + process.emplace(std::move(initProcess), std::move(fds)); + } // Configure console for interactive usage. @@ -1641,7 +1680,7 @@ int WslaShell(_In_ std::wstring_view commandLine) auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); wsl::shared::SocketChannel controlChannel{ - wil::unique_socket{(SOCKET)process.GetStdHandle(2).release()}, "TerminalControl", exitEvent.get()}; + wil::unique_socket{(SOCKET)process->GetStdHandle(2).release()}, "TerminalControl", exitEvent.get()}; std::thread inputThread([&]() { auto updateTerminal = [&controlChannel, &Stdout]() { @@ -1657,7 +1696,7 @@ int WslaShell(_In_ std::wstring_view commandLine) controlChannel.SendMessage(message); }; - wsl::windows::common::relay::StandardInputRelay(Stdin, process.GetStdHandle(0).get(), updateTerminal, exitEvent.get()); + wsl::windows::common::relay::StandardInputRelay(Stdin, process->GetStdHandle(0).get(), updateTerminal, exitEvent.get()); }); auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { @@ -1666,12 +1705,12 @@ int WslaShell(_In_ std::wstring_view commandLine) }); // Relay the contents of the pipe to stdout. - wsl::windows::common::relay::InterruptableRelay(process.GetStdHandle(1).get(), Stdout); + wsl::windows::common::relay::InterruptableRelay(process->GetStdHandle(1).get(), Stdout); } - process.GetExitEvent().wait(); + process->GetExitEvent().wait(); - auto [code, signalled] = process.GetExitState(); + auto [code, signalled] = process->GetExitState(); wprintf(L"%hs exited with: %i%hs", shell.c_str(), code, signalled ? " (signalled)" : ""); return code; diff --git a/src/windows/common/WslCoreFilesystem.cpp b/src/windows/common/WslCoreFilesystem.cpp index 6fd79b31e..92cbe0165 100644 --- a/src/windows/common/WslCoreFilesystem.cpp +++ b/src/windows/common/WslCoreFilesystem.cpp @@ -29,8 +29,10 @@ wil::unique_hfile wsl::core::filesystem::CreateFile( void wsl::core::filesystem::CreateVhd(_In_ LPCWSTR target, _In_ ULONGLONG maximumSize, _In_ PSID userSid, _In_ BOOL sparse, _In_ BOOL fixed) { - WI_ASSERT(wsl::windows::common::string::IsPathComponentEqual( - std::filesystem::path{target}.extension().native(), windows::common::wslutil::c_vhdxFileExtension)); + THROW_HR_IF( + E_INVALIDARG, + !wsl::windows::common::string::IsPathComponentEqual( + std::filesystem::path{target}.extension().native(), windows::common::wslutil::c_vhdxFileExtension)); // Disable creation of sparse VHDs while data corruption is being debugged. if (sparse) diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index fd0f10c06..43da515e4 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -143,7 +143,8 @@ static const std::map g_commonErrors{ X_WIN32(ERROR_OPERATION_ABORTED), X_WIN32(WSAECONNREFUSED), X_WIN32(ERROR_BAD_PATHNAME), - X(WININET_E_TIMEOUT)}; + X(WININET_E_TIMEOUT), + X_WIN32(ERROR_INVALID_SID)}; #undef X diff --git a/src/windows/wslaservice/exe/ServiceProcessLauncher.h b/src/windows/wslaservice/exe/ServiceProcessLauncher.h index e9b61e040..88b4b0693 100644 --- a/src/windows/wslaservice/exe/ServiceProcessLauncher.h +++ b/src/windows/wslaservice/exe/ServiceProcessLauncher.h @@ -24,7 +24,7 @@ class ServiceRunningProcess : public common::RunningWSLAProcess { public: NON_COPYABLE(ServiceRunningProcess); - NON_MOVABLE(ServiceRunningProcess); + DEFAULT_MOVABLE(ServiceRunningProcess); ServiceRunningProcess(const Microsoft::WRL::ComPtr& process, std::vector&& fds); wil::unique_handle GetStdHandle(int Index) override; diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 4e7a68b59..384ea4cdd 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -18,6 +18,14 @@ Module Name: using wsl::windows::service::wsla::WSLAContainer; +const std::string nerdctlPath = "/usr/bin/nerdctl"; + +// Constants for required default arguments for "nerdctl run..." +static std::vector defaultNerdctlRunArgs{//"--pull=never", // TODO: Uncomment once PullImage() is implemented. + "--net=host", // TODO: default for now, change later + "--ulimit", + "nofile=65536:65536"}; + HRESULT WSLAContainer::Start() { return E_NOTIMPL; @@ -38,10 +46,12 @@ HRESULT WSLAContainer::GetState(WSLA_CONTAINER_STATE* State) return E_NOTIMPL; } -HRESULT WSLAContainer::GetInitProcess(IWSLAProcess** process) +HRESULT WSLAContainer::GetInitProcess(IWSLAProcess** Process) +try { - return E_NOTIMPL; + return m_containerProcess.Get().QueryInterface(__uuidof(IWSLAProcess), (void**)Process); } +CATCH_RETURN(); HRESULT WSLAContainer::Exec(const WSLA_PROCESS_OPTIONS* Options, IWSLAProcess** Process, int* Errno) try @@ -53,3 +63,102 @@ try return S_OK; } CATCH_RETURN(); + +Microsoft::WRL::ComPtr WSLAContainer::Create(const WSLA_CONTAINER_OPTIONS& containerOptions, WSLAVirtualMachine& parentVM) +{ + + bool hasStdin = false; + bool hasTty = false; + for (size_t i = 0; i < containerOptions.InitProcessOptions.FdsCount; i++) + { + if (containerOptions.InitProcessOptions.Fds[i].Fd == 0) + { + hasStdin = true; + } + + if (containerOptions.InitProcessOptions.Fds[i].Type == WSLAFdTypeTerminalInput || + containerOptions.InitProcessOptions.Fds[i].Type == WSLAFdTypeTerminalOutput) + { + hasTty = true; + } + } + + std::vector inputOptions; + if (hasStdin) + { + inputOptions.push_back("-i"); + } + + if (hasTty) + { + inputOptions.push_back("-t"); + } + + auto args = PrepareNerdctlRunCommand(containerOptions, std::move(inputOptions)); + + ServiceProcessLauncher launcher(nerdctlPath, args, {}, common::ProcessFlags::None); + for (size_t i = 0; i < containerOptions.InitProcessOptions.FdsCount; i++) + { + launcher.AddFd(containerOptions.InitProcessOptions.Fds[i]); + } + + return wil::MakeOrThrow(&parentVM, launcher.Launch(parentVM)); +} + +std::vector WSLAContainer::PrepareNerdctlRunCommand(const WSLA_CONTAINER_OPTIONS& options, std::vector&& inputOptions) +{ + std::vector args{nerdctlPath}; + args.push_back("run"); + args.push_back("--name"); + args.push_back(options.Name); + if (options.ShmSize > 0) + { + args.push_back("--shm-size=" + std::to_string(options.ShmSize) + 'm'); + } + if (options.Flags & WSLA_CONTAINER_FLAG_ENABLE_GPU) + { + args.push_back("--gpus"); + // TODO: Parse GPU device list from WSLA_CONTAINER_OPTIONS. For now, just enable all GPUs. + args.push_back("all"); + } + + args.insert(args.end(), defaultNerdctlRunArgs.begin(), defaultNerdctlRunArgs.end()); + args.insert(args.end(), inputOptions.begin(), inputOptions.end()); + + for (ULONG i = 0; i < options.InitProcessOptions.EnvironmentCount; i++) + { + THROW_HR_IF_MSG( + E_INVALIDARG, + options.InitProcessOptions.Environment[i][0] == L'-', + "Invlaid environment string: %hs", + options.InitProcessOptions.Environment[i]); + + args.insert(args.end(), {"-e", options.InitProcessOptions.Environment[i]}); + } + + if (options.InitProcessOptions.Executable != nullptr) + { + args.push_back("--entrypoint"); + args.push_back(options.InitProcessOptions.Executable); + } + + // TODO: + // - Implement volume mounts + // - Implement port mapping + + args.push_back(options.Image); + + if (options.InitProcessOptions.CommandLineCount > 0) + { + args.push_back("--"); + + for (ULONG i = 0; i < options.InitProcessOptions.CommandLineCount; i++) + { + args.push_back(options.InitProcessOptions.CommandLine[i]); + } + } + + // TODO: Implement --entrypoint override if specified in WSLA_CONTAINER_OPTIONS. + + return args; +} \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAContainer.h b/src/windows/wslaservice/exe/WSLAContainer.h index f9bed4c84..d84eaac8f 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.h +++ b/src/windows/wslaservice/exe/WSLAContainer.h @@ -14,7 +14,9 @@ Module Name: #pragma once +#include "ServiceProcessLauncher.h" #include "wslaservice.h" +#include "WSLAVirtualMachine.h" namespace wsl::windows::service::wsla { @@ -23,6 +25,10 @@ class DECLSPEC_UUID("B1F1C4E3-C225-4CAE-AD8A-34C004DE1AE4") WSLAContainer { public: WSLAContainer() = default; // TODO + WSLAContainer(WSLAVirtualMachine* parentVM, ServiceRunningProcess&& containerProcess) : + m_parentVM(parentVM), m_containerProcess(std::move(containerProcess)) + { + } WSLAContainer(const WSLAContainer&) = delete; WSLAContainer& operator=(const WSLAContainer&) = delete; @@ -33,6 +39,12 @@ class DECLSPEC_UUID("B1F1C4E3-C225-4CAE-AD8A-34C004DE1AE4") WSLAContainer IFACEMETHOD(GetInitProcess)(_Out_ IWSLAProcess** process) override; IFACEMETHOD(Exec)(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno) override; + static Microsoft::WRL::ComPtr Create(const WSLA_CONTAINER_OPTIONS& Options, WSLAVirtualMachine& parentVM); + private: + ServiceRunningProcess m_containerProcess; + WSLAVirtualMachine* m_parentVM = nullptr; + + static std::vector PrepareNerdctlRunCommand(const WSLA_CONTAINER_OPTIONS& options, std::vector&& inputOptions); }; } // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAProcess.cpp b/src/windows/wslaservice/exe/WSLAProcess.cpp index 8eba6e7bf..95269a36f 100644 --- a/src/windows/wslaservice/exe/WSLAProcess.cpp +++ b/src/windows/wslaservice/exe/WSLAProcess.cpp @@ -44,6 +44,8 @@ void WSLAProcess::OnVmTerminated() { m_state = WslaProcessStateSignalled; m_exitedCode = 9; // SIGKILL + + m_exitEvent.SetEvent(); } } @@ -74,6 +76,11 @@ try RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !socket.is_valid()); *Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess(socket.get())); + WSL_LOG( + "GetStdHandle", + TraceLoggingValue(Index, "fd"), + TraceLoggingValue(socket.get(), "handle"), + TraceLoggingValue(*Handle, "remoteHandle")); socket.reset(); return S_OK; diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 525eb3f4d..bbfceff8a 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -60,6 +60,11 @@ HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName) return S_OK; } +const std::wstring& WSLASession::DisplayName() const +{ + return m_displayName; +} + HRESULT WSLASession::PullImage(LPCWSTR Image, const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryInformation, IProgressCallback* ProgressCallback) { return E_NOTIMPL; @@ -80,14 +85,18 @@ HRESULT WSLASession::DeleteImage(LPCWSTR Image) return E_NOTIMPL; } -HRESULT WSLASession::CreateContainer(const WSLA_CONTAINER_OPTIONS* Options, IWSLAContainer** Container) +HRESULT WSLASession::CreateContainer(const WSLA_CONTAINER_OPTIONS* containerOptions, IWSLAContainer** Container) try { - // Basic instanciation for testing. - // TODO: Implement. + RETURN_HR_IF_NULL(E_POINTER, containerOptions); + + std::lock_guard lock{m_lock}; + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); - auto container = wil::MakeOrThrow(); - container.CopyTo(__uuidof(IWSLAContainer), (void**)Container); + // TODO: Log entrance into the function. + m_containerId++; + auto container = WSLAContainer::Create(*containerOptions, *m_virtualMachine.Get()); + THROW_IF_FAILED(container.CopyTo(__uuidof(IWSLAContainer), (void**)Container)); return S_OK; } diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 843b88d50..362a70e78 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -27,6 +27,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession ~WSLASession(); IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName) override; + const std::wstring& DisplayName() const; // Image management. IFACEMETHOD(PullImage)(_In_ LPCWSTR Image, _In_ const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryInformation, _In_ IProgressCallback* ProgressCallback) override; @@ -56,6 +57,9 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession Microsoft::WRL::ComPtr m_virtualMachine; std::wstring m_displayName; std::mutex m_lock; + + std::atomic_int m_containerId = 1; + // TODO: Add container tracking here. Could reuse m_lock for that. }; } // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index f7d59418d..20eddccb3 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -47,8 +47,7 @@ PSID WSLAUserSessionImpl::GetUserSid() const return m_tokenInfo->User.Sid; } -HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession( - const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) +HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) { auto session = wil::MakeOrThrow(*Settings, *this, *VmSettings); @@ -63,6 +62,24 @@ HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession( return S_OK; } +HRESULT WSLAUserSessionImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLASession** Session) +{ + std::lock_guard lock(m_wslaSessionsLock); + + // TODO: ACL check + // TODO: Check for duplicate on session creation. + for (auto& e : m_sessions) + { + if (e->DisplayName() == DisplayName) + { + THROW_IF_FAILED(e->QueryInterface(__uuidof(IWSLASession), (void**)Session)); + return S_OK; + } + } + + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); +} + wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr&& Session) : m_session(std::move(Session)) { @@ -92,7 +109,18 @@ HRESULT wsl::windows::service::wsla::WSLAUserSession::ListSessions(WSLA_SESSION_ { return E_NOTIMPL; } + HRESULT wsl::windows::service::wsla::WSLAUserSession::OpenSession(ULONG Id, IWSLASession** Session) { return E_NOTIMPL; } + +HRESULT wsl::windows::service::wsla::WSLAUserSession::OpenSessionByName(LPCWSTR DisplayName, IWSLASession** Session) +try +{ + auto session = m_session.lock(); + RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + + return session->OpenSessionByName(DisplayName, Session); +} +CATCH_RETURN(); \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h index a7b6635ce..d0fcca5d8 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.h +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -30,6 +30,7 @@ class WSLAUserSessionImpl PSID GetUserSid() const; HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession); + HRESULT OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session); void OnSessionTerminated(WSLASession* Session); @@ -55,6 +56,7 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) override; IFACEMETHOD(ListSessions)(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount) override; IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLASession** Session) override; + IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session) override; private: std::weak_ptr m_session; diff --git a/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp b/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp index e13effec7..0a28ede8e 100644 --- a/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp @@ -50,11 +50,15 @@ HRESULT WSLAUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ RE THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value()); auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) { - return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid); + return EqualSid(it->GetUserSid(), tokenInfo->User.Sid); }); if (session == g_sessions->end()) { + wil::unique_hlocal_string sid; + THROW_IF_WIN32_BOOL_FALSE(ConvertSidToStringSid(tokenInfo->User.Sid, &sid)); + WSL_LOG("WSLAUserSession created", TraceLoggingValue(sid.get(), "sid")); + session = g_sessions->insert(g_sessions->end(), std::make_shared(userToken.get(), std::move(tokenInfo))); } diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index de9235cf0..a3736c6d9 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -104,7 +104,7 @@ struct WSLA_IMAGE_INFORMATION struct WSLA_PROCESS_OPTIONS { - LPCSTR Executable; + [unique] LPCSTR Executable; [unique] LPCSTR CurrentDirectory; [size_is(CommandLineCount)] LPCSTR* CommandLine; ULONG CommandLineCount; @@ -117,7 +117,8 @@ struct WSLA_PROCESS_OPTIONS struct WSLA_VOLUME { LPCSTR HostPath; - LPCSTR ContainerHostPath; + LPCSTR ContainerPath; + BOOL ReadOnly; }; struct WSLA_PORT_MAPPING @@ -126,11 +127,17 @@ struct WSLA_PORT_MAPPING USHORT ContainerPort; }; +enum WSLA_CONTAINER_FLAGS +{ + WSLA_CONTAINER_FLAG_ENABLE_GPU = 1 +} ; + + struct WSLA_CONTAINER_OPTIONS { LPCSTR Image; LPCSTR Name; - struct WSLA_PROCESS_OPTIONS* InitProcessOptions; + struct WSLA_PROCESS_OPTIONS InitProcessOptions; [unique, size_is(VolumesCount)] struct WSLA_VOLUME* Volumes; ULONG VolumesCount; [unique, size_is(PortsCount)] struct WSLA_PORT_MAPPING* Ports; @@ -312,6 +319,7 @@ interface IWSLAUserSession : IUnknown HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, [in] const VIRTUAL_MACHINE_SETTINGS* VmSettings, [out]IWSLASession** Session); HRESULT ListSessions([out, size_is(, *SessionsCount)] struct WSLA_SESSION_INFORMATION** Sessions, [out] ULONG* SessionsCount); HRESULT OpenSession([in] ULONG Id, [out]IWSLASession** Session); + HRESULT OpenSessionByName([in] LPCWSTR DisplayName, [out]IWSLASession** Session); // TODO: Do we need 'TerminateSession()' ? } \ No newline at end of file diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index f7cb14635..59125bf35 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -17,11 +17,14 @@ Module Name: #include "WSLAApi.h" #include "wslaservice.h" #include "WSLAProcessLauncher.h" +#include "WSLAContainerLauncher.h" #include "WslCoreFilesystem.h" using namespace wsl::windows::common::registry; using wsl::windows::common::ProcessFlags; +using wsl::windows::common::RunningWSLAContainer; using wsl::windows::common::RunningWSLAProcess; +using wsl::windows::common::WSLAContainerLauncher; using wsl::windows::common::WSLAProcessLauncher; using wsl::windows::common::relay::OverlappedIOHandle; using wsl::windows::common::relay::WriteHandle; @@ -53,7 +56,10 @@ class WSLATests wil::com_ptr CreateSession(VIRTUAL_MACHINE_SETTINGS& vmSettings, const WSLA_SESSION_SETTINGS& sessionSettings = {L"wsla-test"}) { - vmSettings.RootVhdType = "ext4"; + if (vmSettings.RootVhdType == nullptr) + { + vmSettings.RootVhdType = "ext4"; + } wil::com_ptr userSession; VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); @@ -132,6 +138,38 @@ class WSLATests return result; } + void ValidateProcessOutput(RunningWSLAProcess& process, const std::map& expectedOutput, int expectedResult = 0) + { + auto result = process.WaitAndCaptureOutput(); + + if (result.Code != expectedResult) + { + LogError( + "Comman didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'", + expectedResult, + result.Code, + result.Output[1].c_str(), + result.Output[2].c_str()); + + return; + } + + for (const auto& [fd, expected] : expectedOutput) + { + auto it = result.Output.find(fd); + if (it == result.Output.end()) + { + LogError("Expected output on fd %i, but none found.", fd); + return; + } + + if (it->second != expected) + { + LogError("Unexpected output on fd %i. Expected: '%hs', Actual: '%hs'", fd, expected.c_str(), it->second.c_str()); + } + } + } + TEST_METHOD(CustomDmesgOutput) { WSL2_TEST_ONLY(); @@ -411,7 +449,7 @@ class WSLATests auto [hresult, _, process] = launcher.LaunchNoThrow(*session); VERIFY_ARE_EQUAL(hresult, expectedError); - return process; + return std::move(process); }; { @@ -1098,4 +1136,82 @@ class WSLATests VERIFY_ARE_EQUAL(session->FormatVirtualDisk(L"DoesNotExist.vhdx"), E_INVALIDARG); VERIFY_ARE_EQUAL(session->FormatVirtualDisk(L"C:\\DoesNotExist.vhdx"), HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)); } + + TEST_METHOD(CreateContainer) + { + WSL2_TEST_ONLY(); + + auto storageVhd = std::filesystem::current_path() / "storage.vhdx"; + + // Create a 1G temporary VHD. + if (!std::filesystem::exists(storageVhd)) + { + wsl::core::filesystem::CreateVhd(storageVhd.native().c_str(), 1024 * 1024 * 1024, nullptr, true, false); + } + + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { LOG_IF_WIN32_BOOL_FALSE(DeleteFileW(storageVhd.c_str())); }); + + VIRTUAL_MACHINE_SETTINGS settings{}; + settings.CpuCount = 4; + settings.DisplayName = L"WSLA"; + settings.MemoryMb = 2048; + settings.BootTimeoutMs = 30 * 1000; + settings.RootVhd = TEXT(WSLA_TEST_DISTRO_PATH); + settings.RootVhdType = "squashfs"; + settings.NetworkingMode = WSLANetworkingModeNAT; + settings.ContainerRootVhd = storageVhd.c_str(); + settings.FormatContainerRootVhd = true; + + auto session = CreateSession(settings); + + // TODO: Remove once the proper rootfs VHD is available. + ExpectCommandResult(session.get(), {"/etc/lsw-init.sh"}, 0); + + // Test a simple container start. + { + WSLAContainerLauncher launcher("debian:latest", "test-simple", "echo", {"OK"}); + auto container = launcher.Launch(*session); + auto process = container.GetInitProcess(); + + ValidateProcessOutput(process, {{1, "OK\n"}}); + } + + // Validate that env is correctly wired. + { + WSLAContainerLauncher launcher("debian:latest", "test-env", "/bin/bash", {"-c", "echo $testenv"}, {{"testenv=testvalue"}}); + auto container = launcher.Launch(*session); + auto process = container.GetInitProcess(); + + ValidateProcessOutput(process, {{1, "testvalue\n"}}); + } + + // Validate that starting containers work with the default entrypoint. + + // TODO: This is hanging. nerdctl run seems to hang with -i is passed outside of a TTY context. + /* + { + WSLAContainerLauncher launcher( + "debian:latest", "test-default-entrypoint", "/bin/cat", {}, {}, ProcessFlags::Stdin | ProcessFlags::Stdout | + ProcessFlags::Stderr); auto container = launcher.Launch(*session); auto process = container.GetInitProcess(); + + std::string shellInput = "echo $SHELL\n exit"; + std::unique_ptr writeStdin( + new WriteHandle(process.GetStdHandle(0), {shellInput.begin(), shellInput.end()})); + std::vector> extraHandles; + extraHandles.emplace_back(std::move(writeStdin)); + + auto result = process.WaitAndCaptureOutput(INFINITE, std::move(extraHandles)); + + VERIFY_ARE_EQUAL(result.Output[1], "foo"); + }*/ + + // Validate that stdin is empty if ProcessFlags::Stdin is not passed. + { + WSLAContainerLauncher launcher("debian:latest", "test-stdin", "/bin/cat"); + auto container = launcher.Launch(*session); + auto process = container.GetInitProcess(); + + ValidateProcessOutput(process, {{1, ""}}); + } + } }; \ No newline at end of file