From 6245ef29221fbcc0c7486789eeb7d5e1bd8fd8ac Mon Sep 17 00:00:00 2001 From: Blue Date: Tue, 23 Dec 2025 11:24:03 -0800 Subject: [PATCH 01/26] Save state --- CMakeLists.txt | 31 ++++++ src/linux/init/WSLAInit.cpp | 104 +++++++++++++++++- src/shared/inc/SocketChannel.h | 5 + src/shared/inc/lxinitshared.h | 72 +++++++++++- src/windows/wslaservice/exe/CMakeLists.txt | 7 +- src/windows/wslaservice/exe/WSLASession.cpp | 63 +++++++++-- src/windows/wslaservice/exe/WSLASession.h | 2 + .../wslaservice/exe/WSLAVirtualMachine.cpp | 37 ++++++- .../wslaservice/exe/WSLAVirtualMachine.h | 13 ++- 9 files changed, 315 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c0a350ad3..6d6095651 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,36 @@ FetchContent_Declare(nlohmannjson FetchContent_MakeAvailable(nlohmannjson) FetchContent_GetProperties(nlohmannjson SOURCE_DIR NLOHMAN_JSON_SOURCE_DIR) +FetchContent_Declare(httplib SYSTEM + GIT_REPOSITORY https://github.com/yhirose/cpp-httplib + GIT_TAG v0.28.0 + GIT_SHALLOW TRUE) + +FetchContent_MakeAvailable(httplib) +FetchContent_GetProperties(httplib SOURCE_DIR HTTP_LIB_SOURCE_DIR) + + +set(BOOST_VERSION "1.90.0") +set(BOOST_TARBALL "boost_${BOOST_VERSION}") +string(REPLACE "." "_" BOOST_TARBALL "${BOOST_TARBALL}") # 1.84.0 -> 1_84_0 + +FetchContent_Declare( + boost_headers + URL https://archives.boost.io/release/${BOOST_VERSION}/source/${BOOST_TARBALL}.tar.gz + # You can add URL_HASH to pin integrity: + # URL_HASH SHA256= +) + +# Download & unpack to boost_headers_SOURCE_DIR (no add_subdirectory!) +FetchContent_GetProperties(boost_headers) +if(NOT boost_headers_POPULATED) + FetchContent_Populate(boost_headers) +endif() + +# Path where Boost headers reside (root has 'boost/' folder) +include_directories(${boost_headers_SOURCE_DIR}) + + # Import modules list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") find_package(IDL REQUIRED) @@ -314,6 +344,7 @@ set(LINUX_COMMON_FLAGS --gcc-toolchain=${LINUXSDK_PATH} -I "${CMAKE_CURRENT_LIST_DIR}/src/shared/configfile" -I "${CMAKE_CURRENT_LIST_DIR}/src/shared/inc" -I "${NLOHMAN_JSON_SOURCE_DIR}/include" + -I "${HTTP_LIB_SOURCE_DIR}" -I "${CMAKE_BINARY_DIR}/generated" --no-standard-libraries -Werror diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index 7f2da7276..be6d79312 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -182,6 +182,108 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_OPEN& Mes result = 0; } +void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_UNIX_CONNECT& Message, const gsl::span& Buffer) +{ + int result = -1; + + auto sendResult = wil::scope_exit([&]() { Channel.SendResultMessage(result); }); + + wil::unique_fd socket; + + const auto* path = wsl::shared::string::FromSpan(Buffer, Message.PathOffset); + THROW_ERRNO_IF(EINVAL, path == nullptr); + + try + { + socket = UtilConnectUnix(path); + result = 0; + } + catch (...) + { + result = wil::ResultFromCaughtException(); + } + + if (result != 0) + { + return; + } + + sendResult.reset(); + + LOG_ERROR("Connected to unix socket {}", path); + + // Relay data between the two sockets. + + pollfd pollDescriptors[2]; + pollDescriptors[0].fd = socket.get(); + pollDescriptors[0].events = POLLIN; + pollDescriptors[1].fd = Channel.Socket(); + pollDescriptors[1].events = POLLIN; + + std::vector relayBuffer; + while (true) + { + auto result = poll(pollDescriptors, COUNT_OF(pollDescriptors), -1); + THROW_LAST_ERROR_IF(result < 0); + + if (pollDescriptors[0].revents & (POLLIN | POLLHUP | POLLERR)) + { + auto bytesRead = UtilReadBuffer(pollDescriptors[0].fd, relayBuffer); + if (bytesRead < 0) + { + LOG_ERROR("read failed {}", errno); + break; + } + else if (bytesRead == 0) + { + // Unix socket has been closed. + pollDescriptors[0].fd = -1; + break; + } + else + { + auto bytesWritten = write(Channel.Socket(), relayBuffer.data(), bytesRead); + if (bytesWritten < 0) + { + LOG_ERROR("write failed {}", errno); + break; + } + + LOG_ERROR("Relayed: {} bytes from unix socket to hvsocket", bytesWritten); + } + } + + if (pollDescriptors[1].revents & (POLLIN | POLLHUP | POLLERR)) + { + auto bytesRead = UtilReadBuffer(pollDescriptors[1].fd, relayBuffer); + if (bytesRead < 0) + { + LOG_ERROR("read failed {}", errno); + break; + } + else if (bytesRead == 0) + { + // hvsocket has been closed. + pollDescriptors[1].fd = -1; + break; + } + else + { + auto bytesWritten = write(socket.get(), relayBuffer.data(), bytesRead); + if (bytesWritten < 0) + { + LOG_ERROR("write failed {}", errno); + break; + } + + LOG_ERROR("Relayed: {} bytes from hvsocket to unix socket", bytesWritten); + } + } + } + + LOG_ERROR("Relay exited"); +} + void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_TTY_RELAY& Message, const gsl::span&) { THROW_LAST_ERROR_IF(fcntl(Message.TtyMaster, F_SETFL, O_NONBLOCK) < 0); @@ -766,7 +868,7 @@ void ProcessMessage(wsl::shared::SocketChannel& Channel, LX_MESSAGE_TYPE Type, c { try { - HandleMessage( + HandleMessage( Channel, Type, Buffer); } catch (...) diff --git a/src/shared/inc/SocketChannel.h b/src/shared/inc/SocketChannel.h index f472ed021..7ecf3440e 100644 --- a/src/shared/inc/SocketChannel.h +++ b/src/shared/inc/SocketChannel.h @@ -318,6 +318,11 @@ class SocketChannel return m_socket.get(); } + auto Release() + { + return std::move(m_socket); + } + bool Connected() const { return m_socket.get() >= 0; diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index b64fbf13e..fa767ea9b 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -397,7 +397,10 @@ typedef enum _LX_MESSAGE_TYPE LxMessageWSLADetach, LxMessageWSLATerminalChanged, LxMessageWSLAWatchProcesses, - LxMessageWSLAProcessExited + LxMessageWSLAProcessExited, + LxMessageWSLAHTTPResponse, + LxMessageWSLAHTTPRequest, + LxMessageWSLAUnixConnect, } LX_MESSAGE_TYPE, *PLX_MESSAGE_TYPE; @@ -508,7 +511,9 @@ inline auto ToString(LX_MESSAGE_TYPE messageType) X(LxMessageWSLADetach) X(LxMessageWSLATerminalChanged) X(LxMessageWSLAWatchProcesses) - X(LxMessageWSLAProcessExited) + X(LxMessageWSLAHTTPResponse) + X(LxMessageWSLAHTTPRequest) + X(LxMessageWSLAUnixConnect) default: return ""; @@ -1881,6 +1886,69 @@ struct WSLA_PROCESS_EXITED PRETTY_PRINT(FIELD(Header), FIELD(Pid), FIELD(Code), FIELD(Signaled)); }; +struct WSLA_HTTP_RESPONSE +{ + static inline auto Type = LxMessageWSLAHTTPResponse; + + int Errno; + unsigned int StatusCode; + unsigned int ContentSize; +}; + +enum class HTTPMethod +{ + GET, + POST +}; + +inline auto ToString(HTTPMethod type) +{ + if (type == HTTPMethod::GET) + { + return "GET"; + } + else if (type == HTTPMethod::POST) + { + return "POST"; + } + else + { + return "Unknown"; + } +} + +inline void PrettyPrint(std::stringstream& Out, HTTPMethod Value) +{ + Out << ToString(Value); +} + +struct WSLA_HTTP_REQUEST +{ + static inline auto Type = LxMessageWSLAHTTPRequest; + using TResponse = WSLA_HTTP_RESPONSE; + + HTTPMethod Method; + unsigned int UrlOffset; + unsigned int BodyOffset; + unsigned int ContentTypeOffset; + + PRETTY_PRINT(FIELD(Method), STRING_FIELD(UrlOffset), STRING_FIELD(BodyOffset), STRING_FIELD(ContentTypeOffset)); +}; + +struct WSLA_UNIX_CONNECT +{ + static inline auto Type = LxMessageWSLAUnixConnect; + using TResponse = RESULT_MESSAGE; + + DECLARE_MESSAGE_CTOR(WSLA_UNIX_CONNECT); + + MESSAGE_HEADER Header; + unsigned int PathOffset; + char Buffer[]; + + PRETTY_PRINT(FIELD(Header), STRING_FIELD(PathOffset)); +}; + typedef struct _LX_MINI_INIT_IMPORT_RESULT { static inline auto Type = LxMiniInitMessageImportResult; diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt index 34fdde509..3fd90b5d5 100644 --- a/src/windows/wslaservice/exe/CMakeLists.txt +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -1,6 +1,7 @@ set(SOURCES application.manifest ContainerEventTracker.cpp + DockerHTTPClient.cpp main.rc ServiceMain.cpp ServiceProcessLauncher.cpp @@ -14,6 +15,7 @@ set(SOURCES set(HEADERS ContainerEventTracker.h + DockerHTTPClient.h ServiceProcessLauncher.h WSLAContainer.h WSLAProcess.h @@ -38,7 +40,8 @@ target_link_libraries(wslaservice legacy_stdio_definitions VirtDisk.lib Winhttp.lib - Synchronization.lib) + Synchronization.lib + httplib) target_precompile_headers(wslaservice REUSE_FROM common) -set_target_properties(wslaservice PROPERTIES FOLDER windows) \ No newline at end of file +set_target_properties(wslaservice PROPERTIES FOLDER windows) diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index c2000e0a8..06ac1c732 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -18,12 +18,13 @@ Module Name: #include "WSLAContainer.h" #include "ServiceProcessLauncher.h" #include "WslCoreFilesystem.h" +#include "DockerHTTPClient.h" using namespace wsl::windows::common; using wsl::windows::service::wsla::WSLASession; using wsl::windows::service::wsla::WSLAVirtualMachine; -constexpr auto c_containerdStorage = "/var/lib/containerd"; +constexpr auto c_containerdStorage = "/var/lib/docker"; WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl) : @@ -55,8 +56,8 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs // Launch containerd // TODO: Rework the daemon logic so we can have only one thread watching all daemons. ServiceProcessLauncher launcher{ - "/usr/bin/containerd", - {"/usr/bin/containerd"}, + "/usr/bin/dockerd", + {"/usr/bin/dockerd"}, {{"PATH=/bin:/usr/local/sbin:/usr/bin:/usr/sbin:/sbin"}}, common::ProcessFlags::Stdout | common::ProcessFlags::Stderr}; m_containerdThread = std::thread(&WSLASession::MonitorContainerd, this, launcher.Launch(*m_virtualMachine.Get())); @@ -65,8 +66,18 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs // TODO: Configurable timeout. THROW_WIN32_IF_MSG(ERROR_TIMEOUT, !m_containerdReadyEvent.wait(10 * 1000), "Timed out waiting for containerd to start"); + auto [_, __, channel] = m_virtualMachine->Fork(WSLA_FORK::Thread); + + DockerHTTPClient client(std::move(channel), m_virtualMachine->ExitingEvent(), m_virtualMachine->VmId(), 10 * 1000); + + client.SendRequest(boost::beast::http::verb::get, "/info", [](const gsl::span& span) { + WSL_LOG("Response", TraceLoggingValue(span.data(), "data")); + }); + // auto response = DockerRequest("/info"); + + //WSL_LOG("Info", TraceLoggingValue(response.c_str(), "DockerInfo")); // Start the event tracker. - m_eventTracker.emplace(*m_virtualMachine.Get()); + // m_eventTracker.emplace(*m_virtualMachine.Get()); errorCleanup.release(); } @@ -239,22 +250,18 @@ try return; } - constexpr auto c_containerdReadyLogLine = "containerd successfully booted"; + constexpr auto c_containerdReadyLogLine = "API listen on /var/run/docker.sock"; std::string entry = {buffer.begin(), buffer.end()}; WSL_LOG("ContainerdLog", TraceLoggingValue(entry.c_str(), "Content"), TraceLoggingValue(m_displayName.c_str(), "Name")); - auto parsed = nlohmann::json::parse(entry); + // auto parsed = nlohmann::json::parse(entry); if (!m_containerdReadyEvent.is_signaled()) { - auto it = parsed.find("msg"); - if (it != parsed.end()) + if (entry.find(c_containerdReadyLogLine) != std::string::npos) { - if (it->get().starts_with(c_containerdReadyLogLine)) - { - m_containerdReadyEvent.SetEvent(); - } + m_containerdReadyEvent.SetEvent(); } } } @@ -536,3 +543,35 @@ void WSLASession::OnContainerDeleted(const WSLAContainerImpl* Container) std::lock_guard lock{m_lock}; WI_VERIFY(std::erase_if(m_containers, [Container](const auto& e) { return e.second.get() == Container; }) == 1); } + +std::string WSLASession::DockerRequest(const std::string& Url) +{ + wil::unique_socket socket; + { + std::lock_guard lock{m_lock}; + socket = m_virtualMachine->ConnectUnixSocket("/var/run/docker.sock"); + } + + namespace http = boost::beast::http; + + boost::asio::io_context context; + + boost::asio::generic::stream_protocol hv_proto(AF_HYPERV, SOCK_STREAM); + + boost::asio::generic::stream_protocol::stream_protocol::socket stream(context); + stream.assign(hv_proto, socket.release()); + + http::request req{http::verb::get, Url, 11}; + req.set(http::field::host, "hvsocket"); // label only; AF_HYPERV doesn't do DNS + req.prepare_payload(); + + WSL_LOG("Written", TraceLoggingValue(http::write(stream, req), "Bytes")); + + boost::beast::flat_buffer buffer; // for header/body framing + http::response res; + http::read(stream, buffer, res); + + WSL_LOG("HTTPREsult", TraceLoggingValue(res.result_int(), "Status")); + + return res.body(); +} \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index df58a9fde..429ea8a18 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -72,6 +72,8 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession void OnContainerdLog(const gsl::span& Data); void MonitorContainerd(ServiceRunningProcess&& process); + std::string DockerRequest(const std::string& Url); + WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not WSLAUserSessionImpl* m_userSession = nullptr; Microsoft::WRL::ComPtr m_virtualMachine; diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 5f8c90f22..8f744da71 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -387,7 +387,28 @@ void WSLAVirtualMachine::ConfigureMounts() Mount(m_initChannel, nullptr, "/sys", "sysfs", "", 0); Mount(m_initChannel, nullptr, "/proc", "proc", "", 0); Mount(m_initChannel, nullptr, "/dev/pts", "devpts", "noatime,nosuid,noexec,gid=5,mode=620", 0); - Mount(m_initChannel, nullptr, "/sys/fs/cgroup", "cgroup2", "", 0); + Mount(m_initChannel, nullptr, "/sys/fs/cgroup", "tmpfs", "uid=0,gid=0,mode=0755", 0); + + std::vector cgroups = { + "cpuset", + "cpu", + "cpuacct", + "blkio", + "memory", + "devices", + "freezer", + "net_cls", + "perf_event", + "net_prio", + "hugetlb", + "pids", + "rdma" + }; + + for (const auto* e : cgroups) + { + Mount(m_initChannel, nullptr, std::format("/sys/fs/cgroup/{}", e).c_str(), "cgroup", e, 0); + } if (FeatureEnabled(WslaFeatureFlagsGPU)) // TODO: re-think how GPU settings should work at the session level API. { @@ -1523,4 +1544,18 @@ void WSLAVirtualMachine::ReleasePorts(const std::set& Ports) WI_VERIFY(m_allocatedPorts.erase(port) == 1); } +} + +wil::unique_socket WSLAVirtualMachine::ConnectUnixSocket(const char *Path) +{ + auto [_, __, channel] = Fork(WSLA_FORK::Thread); + + shared::MessageWriter message; + message.WriteString(message->PathOffset, Path); + + auto result = channel.Transaction(message.Span()); + + THROW_HR_IF_MSG(E_FAIL, result.Result < 0, "Failed to connect to unix socket: '%hs', %i", Path, result.Result); + + return channel.Release(); } \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index b28a64a03..255987aea 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -97,6 +97,18 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine void Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); const wil::unique_event& TerminatingEvent(); + wil::unique_socket ConnectUnixSocket(_In_ const char* Path); + std::tuple Fork(enum WSLA_FORK::ForkType Type); + HANDLE ExitingEvent() const + { + return m_vmTerminatingEvent.get(); + } + + GUID VmId() const + { + return m_vmId; + } + private: static void Mount(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); @@ -111,7 +123,6 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine void OnCrash(_In_ const HCS_EVENT* Event); bool FeatureEnabled(WSLAFeatureFlags Flag) const; - std::tuple Fork(enum WSLA_FORK::ForkType Type); std::tuple Fork( wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type, ULONG TtyRows = 0, ULONG TtyColumns = 0); int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel); From 9da8be2a609bd43c756ba3a1b4a2988fad0d95cf Mon Sep 17 00:00:00 2001 From: Blue Date: Tue, 23 Dec 2025 11:57:26 -0800 Subject: [PATCH 02/26] Save state --- .../wslaservice/exe/DockerHTTPClient.cpp | 109 ++++++++++++++++++ .../wslaservice/exe/DockerHTTPClient.h | 32 +++++ 2 files changed, 141 insertions(+) create mode 100644 src/windows/wslaservice/exe/DockerHTTPClient.cpp create mode 100644 src/windows/wslaservice/exe/DockerHTTPClient.h diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp new file mode 100644 index 000000000..7dfeaba6e --- /dev/null +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -0,0 +1,109 @@ +#include "precomp.h" + +#include "DockerHTTPClient.h" + +using wsl::windows::service::wsla::DockerHTTPClient; + +DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE exitingEvent, GUID VmId, ULONG ConnectTimeoutMs) : + m_exitingEvent(exitingEvent), m_channel(std::move(Channel)), m_vmId(VmId), m_connectTimeoutMs(ConnectTimeoutMs) +{ +} + +wil::unique_socket DockerHTTPClient::ConnectSocket() +{ + auto lock = m_lock.lock_exclusive(); + + // Send a fork message. + WSLA_FORK message; + message.ForkType = WSLA_FORK::Thread; + const auto& response = m_channel.Transaction(message); + + THROW_HR_IF_MSG(E_FAIL, response.Pid <= 0, "fork() returned %i", response.Pid); + + // Connect the new hvsocket. + wsl::shared::SocketChannel newChannel{ + wsl::windows::common::hvsocket::Connect(m_vmId, response.Port, m_exitingEvent, m_connectTimeoutMs), "DockerClient", m_exitingEvent}; + lock.reset(); + + // Connect that socket to the docker unix socket. + shared::MessageWriter writer; + writer.WriteString(writer->PathOffset, "/var/run/docker.sock"); + + auto result = newChannel.Transaction(writer.Span()); + THROW_HR_IF_MSG(E_FAIL, result.Result < 0, "Failed to connect to unix socket: '/var/run/docker.sock', %i", result.Result); + + return newChannel.Release(); +} + +void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::string& Url, const OnResponseBytes& OnResponse, const std::string& Body) +{ + namespace http = boost::beast::http; + + boost::asio::io_context context; + boost::asio::generic::stream_protocol::socket stream(context); + + // Write the request + { + boost::asio::generic::stream_protocol hv_proto(AF_HYPERV, SOCK_STREAM); + stream.assign(hv_proto, ConnectSocket().release()); + + // boost::beast::basic_stream wrapped_stream; + + http::request req{http::verb::get, Url, 11}; + req.set(http::field::host, "hvsocket"); // label only; AF_HYPERV doesn't do DNS + req.prepare_payload(); + + http::write(stream, req); + } + + wil::unique_socket socket{stream.release()}; + + // Parse the response header + std::vector buffer(16 * 4096); + http::response_parser parser; + parser.eager(false); + parser.skip(false); + + size_t lineFeeds = 0; + + // Consume the socket until the header is reached + while (!parser.is_header_done()) + { + // Peek for the end of the HTTP header '\r\n' + auto bytesRead = common::socket::Receive( + socket.get(), gsl::span(reinterpret_cast(buffer.data()), buffer.size()), m_exitingEvent, MSG_PEEK); + + size_t i{}; + for (i = 0; i < bytesRead && lineFeeds < 2; i++) + { + if (buffer[i] == '\n') + { + lineFeeds++; + } + else if (buffer[i] != '\r') + { + lineFeeds = 0; + } + } + + // Consumme the buffer from the socket. + bytesRead = common::socket::Receive(socket.get(), gsl::span(reinterpret_cast(buffer.data()), i), m_exitingEvent); + WI_ASSERT(bytesRead == i); + + boost::beast::error_code error; + parser.put(boost::asio::buffer(buffer.data(), bytesRead), error); + THROW_HR_IF(E_UNEXPECTED, error && error != boost::beast::http::error::need_more); + } + + WSL_LOG("HTTPResult", TraceLoggingValue(parser.get().result_int(), "Status")); + + while (true) + { + auto bytesRead = common::socket::Receive(socket.get(), gsl::span(reinterpret_cast(buffer.data()), buffer.size()), m_exitingEvent); + if (bytesRead == 0) + { + break; + } + OnResponse(gsl::span{buffer.data(), gsl::narrow_cast(bytesRead)}); + } +} diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.h b/src/windows/wslaservice/exe/DockerHTTPClient.h new file mode 100644 index 000000000..506b2b954 --- /dev/null +++ b/src/windows/wslaservice/exe/DockerHTTPClient.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include +#include + +namespace wsl::windows::service::wsla { + +class DockerHTTPClient +{ + NON_COPYABLE(DockerHTTPClient); + +public: + using OnResponseBytes = std::function)>; + + DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE ExitingEvent, GUID VmId, ULONG ConnectTimeoutMs); + + void SendRequest( + boost::beast::http::verb Method, const std::string& Url, const OnResponseBytes& OnResponse, const std::string& Body = ""); + +private: + + wil::unique_socket ConnectSocket(); + + ULONG m_connectTimeoutMs{}; + GUID m_vmId; + shared::SocketChannel m_channel; + HANDLE m_exitingEvent; + wil::srwlock m_lock; +}; +} // namespace wsl::windows::service::wsla \ No newline at end of file From d64175b905a0f1fe4f10e6bd1e87cb83316cb0e4 Mon Sep 17 00:00:00 2001 From: Blue Date: Tue, 23 Dec 2025 13:05:33 -0800 Subject: [PATCH 03/26] Save state --- .../wslaservice/exe/DockerHTTPClient.cpp | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp index 7dfeaba6e..ad1b8cc40 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.cpp +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -97,13 +97,19 @@ void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::s WSL_LOG("HTTPResult", TraceLoggingValue(parser.get().result_int(), "Status")); - while (true) + boost::asio::generic::stream_protocol hv_proto(AF_HYPERV, SOCK_STREAM); + stream.assign(hv_proto, socket.release()); + + while (!parser.is_done()) { - auto bytesRead = common::socket::Receive(socket.get(), gsl::span(reinterpret_cast(buffer.data()), buffer.size()), m_exitingEvent); - if (bytesRead == 0) - { - break; - } - OnResponse(gsl::span{buffer.data(), gsl::narrow_cast(bytesRead)}); + boost::beast::flat_buffer adapter; + + parser.get().body().data = buffer.data(); + parser.get().body().size = buffer.size(); + http::read(stream, adapter, parser); + + auto bytesRead = parser.get().body().size - buffer.size(); + + OnResponse(gsl::span{buffer.data(), bytesRead}); } } From 30cbabefabf8665553b4117553d350d68ac0abde Mon Sep 17 00:00:00 2001 From: Blue Date: Tue, 23 Dec 2025 15:22:55 -0800 Subject: [PATCH 04/26] Save state --- src/windows/common/WslClient.cpp | 2 + .../wslaservice/exe/DockerHTTPClient.cpp | 65 +++++++++++-------- .../wslaservice/exe/DockerHTTPClient.h | 7 +- src/windows/wslaservice/exe/WSLASession.cpp | 28 ++++---- src/windows/wslaservice/exe/WSLASession.h | 2 + 5 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 7fbfc72d8..72204434d 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1643,6 +1643,8 @@ int WslaShell(_In_ std::wstring_view commandLine) } else { + THROW_IF_FAILED(session->PullImage(containerImage.c_str(), nullptr, nullptr)); + std::vector fds{ WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}, WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}, diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp index ad1b8cc40..dba77294a 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.cpp +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -2,6 +2,7 @@ #include "DockerHTTPClient.h" +using boost::beast::http::verb; using wsl::windows::service::wsla::DockerHTTPClient; DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE exitingEvent, GUID VmId, ULONG ConnectTimeoutMs) : @@ -9,6 +10,15 @@ DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE { } +uint32_t DockerHTTPClient::PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback) +{ + auto [code, _] = SendRequest(verb::post, std::format("http://localhost/images/create?fromImage=library/{}&tag={}", Name, Tag), {}, [Callback](const gsl::span& span) { + Callback(std::string{span.data(), span.size()}); + }); + + return code; +} + wil::unique_socket DockerHTTPClient::ConnectSocket() { auto lock = m_lock.lock_exclusive(); @@ -35,7 +45,7 @@ wil::unique_socket DockerHTTPClient::ConnectSocket() return newChannel.Release(); } -void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::string& Url, const OnResponseBytes& OnResponse, const std::string& Body) +std::pair DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::string& Url, const std::string& Body, const OnResponseBytes& OnResponse) { namespace http = boost::beast::http; @@ -43,20 +53,16 @@ void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::s boost::asio::generic::stream_protocol::socket stream(context); // Write the request - { - boost::asio::generic::stream_protocol hv_proto(AF_HYPERV, SOCK_STREAM); - stream.assign(hv_proto, ConnectSocket().release()); - - // boost::beast::basic_stream wrapped_stream; - - http::request req{http::verb::get, Url, 11}; - req.set(http::field::host, "hvsocket"); // label only; AF_HYPERV doesn't do DNS - req.prepare_payload(); + boost::asio::generic::stream_protocol hv_proto(AF_HYPERV, SOCK_STREAM); + stream.assign(hv_proto, ConnectSocket().release()); - http::write(stream, req); - } + http::request req{Method, Url, 11}; + req.set(http::field::host, "localhost"); + req.set(http::field::connection, "close"); + req.set(http::field::accept, "application/json"); + req.prepare_payload(); - wil::unique_socket socket{stream.release()}; + http::write(stream, req); // Parse the response header std::vector buffer(16 * 4096); @@ -66,12 +72,12 @@ void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::s size_t lineFeeds = 0; - // Consume the socket until the header is reached + // Consume the socket until the header end is reached while (!parser.is_header_done()) { // Peek for the end of the HTTP header '\r\n' auto bytesRead = common::socket::Receive( - socket.get(), gsl::span(reinterpret_cast(buffer.data()), buffer.size()), m_exitingEvent, MSG_PEEK); + stream.native_handle(), gsl::span(reinterpret_cast(buffer.data()), buffer.size()), m_exitingEvent, MSG_PEEK); size_t i{}; for (i = 0; i < bytesRead && lineFeeds < 2; i++) @@ -87,7 +93,7 @@ void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::s } // Consumme the buffer from the socket. - bytesRead = common::socket::Receive(socket.get(), gsl::span(reinterpret_cast(buffer.data()), i), m_exitingEvent); + bytesRead = common::socket::Receive(stream.native_handle(), gsl::span(reinterpret_cast(buffer.data()), i), m_exitingEvent); WI_ASSERT(bytesRead == i); boost::beast::error_code error; @@ -95,21 +101,26 @@ void DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::s THROW_HR_IF(E_UNEXPECTED, error && error != boost::beast::http::error::need_more); } - WSL_LOG("HTTPResult", TraceLoggingValue(parser.get().result_int(), "Status")); - - boost::asio::generic::stream_protocol hv_proto(AF_HYPERV, SOCK_STREAM); - stream.assign(hv_proto, socket.release()); + WSL_LOG("HTTPResult", TraceLoggingValue(Url.c_str(), "Url"), TraceLoggingValue(parser.get().result_int(), "Status")); - while (!parser.is_done()) + if (OnResponse) { - boost::beast::flat_buffer adapter; + while (!parser.is_done()) + { + boost::beast::flat_buffer adapter; - parser.get().body().data = buffer.data(); - parser.get().body().size = buffer.size(); - http::read(stream, adapter, parser); + parser.get().body().data = buffer.data(); + parser.get().body().size = buffer.size(); + http::read(stream, adapter, parser); - auto bytesRead = parser.get().body().size - buffer.size(); - OnResponse(gsl::span{buffer.data(), bytesRead}); + WSL_LOG("Sizes", TraceLoggingValue(parser.get().body().size)); + + auto bytesRead = buffer.size() - parser.get().body().size; + + OnResponse(gsl::span{buffer.data(), bytesRead}); + } } + + return {parser.get().result_int(), wil::unique_socket{stream.release()}}; } diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.h b/src/windows/wslaservice/exe/DockerHTTPClient.h index 506b2b954..269ae9fd6 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.h +++ b/src/windows/wslaservice/exe/DockerHTTPClient.h @@ -13,14 +13,15 @@ class DockerHTTPClient public: using OnResponseBytes = std::function)>; + using OnImageProgress = std::function; DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE ExitingEvent, GUID VmId, ULONG ConnectTimeoutMs); - void SendRequest( - boost::beast::http::verb Method, const std::string& Url, const OnResponseBytes& OnResponse, const std::string& Body = ""); -private: + uint32_t PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback); + std::pair SendRequest(boost::beast::http::verb Method, const std::string& Url, const std::string& Body = "", const OnResponseBytes& OnResponse = {}); +private: wil::unique_socket ConnectSocket(); ULONG m_connectTimeoutMs{}; diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 06ac1c732..94abcb1cd 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -18,7 +18,6 @@ Module Name: #include "WSLAContainer.h" #include "ServiceProcessLauncher.h" #include "WslCoreFilesystem.h" -#include "DockerHTTPClient.h" using namespace wsl::windows::common; using wsl::windows::service::wsla::WSLASession; @@ -68,16 +67,11 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs auto [_, __, channel] = m_virtualMachine->Fork(WSLA_FORK::Thread); - DockerHTTPClient client(std::move(channel), m_virtualMachine->ExitingEvent(), m_virtualMachine->VmId(), 10 * 1000); + m_dockerClient.emplace(std::move(channel), m_virtualMachine->ExitingEvent(), m_virtualMachine->VmId(), 10 * 1000); - client.SendRequest(boost::beast::http::verb::get, "/info", [](const gsl::span& span) { - WSL_LOG("Response", TraceLoggingValue(span.data(), "data")); - }); - // auto response = DockerRequest("/info"); - - //WSL_LOG("Info", TraceLoggingValue(response.c_str(), "DockerInfo")); - // Start the event tracker. - // m_eventTracker.emplace(*m_virtualMachine.Get()); + // WSL_LOG("Info", TraceLoggingValue(response.c_str(), "DockerInfo")); + // Start the event tracker. + // m_eventTracker.emplace(*m_virtualMachine.Get()); errorCleanup.release(); } @@ -310,10 +304,18 @@ try std::lock_guard lock{m_lock}; - ServiceProcessLauncher launcher{nerdctlPath, {nerdctlPath, "pull", ImageUri}}; - auto result = launcher.Launch(*m_virtualMachine.Get()).WaitAndCaptureOutput(); + std::string image{ImageUri}; + size_t separator = image.find(':'); + THROW_HR_IF_MSG(E_INVALIDARG, separator == std::string::npos || separator >= image.size() - 1, "Invalid image: %hs", ImageUri); + + auto callback = [&](const std::string &content) + { + WSL_LOG("ImagePullProgress", TraceLoggingValue(ImageUri, "Image"), TraceLoggingValue(content.c_str(), "Content")); + }; + + auto code = m_dockerClient->PullImage(image.substr(0, separator).c_str(), image.substr(separator + 1).c_str(), callback); - RETURN_HR_IF_MSG(E_FAIL, result.Code != 0, "Pull image failed: %hs", launcher.FormatResult(result).c_str()); + THROW_HR_IF_MSG(E_FAIL, code != 200, "Failed to pull image: %hs", ImageUri); return S_OK; } diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 429ea8a18..447d6a840 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -18,6 +18,7 @@ Module Name: #include "WSLAVirtualMachine.h" #include "WSLAContainer.h" #include "ContainerEventTracker.h" +#include "DockerHTTPClient.h" namespace wsl::windows::service::wsla { @@ -75,6 +76,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession std::string DockerRequest(const std::string& Url); WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not + std::optional m_dockerClient; WSLAUserSessionImpl* m_userSession = nullptr; Microsoft::WRL::ComPtr m_virtualMachine; std::optional m_eventTracker; From 65200871f7775e019554d42d2e1b422582a0e929 Mon Sep 17 00:00:00 2001 From: Blue Date: Tue, 23 Dec 2025 18:29:32 -0800 Subject: [PATCH 05/26] Interactive shell POC --- src/linux/init/WSLAInit.cpp | 11 +-- src/shared/inc/lxinitshared.h | 2 +- src/windows/common/WslClient.cpp | 25 ++++- src/windows/wslaservice/exe/CMakeLists.txt | 1 + .../wslaservice/exe/DockerHTTPClient.cpp | 67 +++++++++++-- .../wslaservice/exe/DockerHTTPClient.h | 60 +++++++++++- src/windows/wslaservice/exe/WSLAContainer.cpp | 93 ++++++------------- src/windows/wslaservice/exe/WSLAContainer.h | 14 ++- src/windows/wslaservice/exe/WSLASession.cpp | 9 +- .../wslaservice/exe/WSLAVirtualMachine.cpp | 7 +- .../wslaservice/exe/WSLAVirtualMachine.h | 1 - src/windows/wslaservice/exe/docker_schema.h | 30 ++++++ src/windows/wslaservice/inc/wslaservice.idl | 1 + 13 files changed, 224 insertions(+), 97 deletions(-) create mode 100644 src/windows/wslaservice/exe/docker_schema.h diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index be6d79312..dcb3d8bfd 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -184,6 +184,10 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_OPEN& Mes void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_UNIX_CONNECT& Message, const gsl::span& Buffer) { + + // Make sure to close the channel since no more messages can be processed after this. + auto closeChannel = wil::scope_exit([&]() { Channel.Close(); }); + int result = -1; auto sendResult = wil::scope_exit([&]() { Channel.SendResultMessage(result); }); @@ -210,10 +214,7 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_UNIX_CONN sendResult.reset(); - LOG_ERROR("Connected to unix socket {}", path); - // Relay data between the two sockets. - pollfd pollDescriptors[2]; pollDescriptors[0].fd = socket.get(); pollDescriptors[0].events = POLLIN; @@ -248,8 +249,6 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_UNIX_CONN LOG_ERROR("write failed {}", errno); break; } - - LOG_ERROR("Relayed: {} bytes from unix socket to hvsocket", bytesWritten); } } @@ -275,8 +274,6 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_UNIX_CONN LOG_ERROR("write failed {}", errno); break; } - - LOG_ERROR("Relayed: {} bytes from hvsocket to unix socket", bytesWritten); } } } diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index fa767ea9b..0cc144377 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -1946,7 +1946,7 @@ struct WSLA_UNIX_CONNECT unsigned int PathOffset; char Buffer[]; - PRETTY_PRINT(FIELD(Header), STRING_FIELD(PathOffset)); + PRETTY_PRINT(FIELD(Header), STRING_FIELD(PathOffset)); }; typedef struct _LX_MINI_INIT_IMPORT_RESULT diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 72204434d..529b001a2 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1663,8 +1663,8 @@ int WslaShell(_In_ std::wstring_view commandLine) THROW_IF_FAILED(session->CreateContainer(&containerOptions, &container.value())); wil::com_ptr initProcess; - THROW_IF_FAILED((*container)->GetInitProcess(&initProcess)); - process.emplace(std::move(initProcess), std::move(fds)); + // THROW_IF_FAILED((*container)->GetInitProcess(&initProcess)); + // process.emplace(std::move(initProcess), std::move(fds)); } // Save original console modes so they can be restored on exit. @@ -1692,9 +1692,28 @@ int WslaShell(_In_ std::wstring_view commandLine) THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8)); + auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); + + if (!containerImage.empty()) + { + wil::unique_handle ttyHandle; + THROW_IF_FAILED(container->get()->GetTtyHandle(reinterpret_cast(&ttyHandle))); + + std::thread inputThread( + [&]() { wsl::windows::common::relay::StandardInputRelay(Stdin, ttyHandle.get(), []() {}, exitEvent.get()); }); + + auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + exitEvent.SetEvent(); + inputThread.join(); + }); + + // Relay the contents of the pipe to stdout. + wsl::windows::common::relay::InterruptableRelay(ttyHandle.get(), Stdout); + return 0; + } + { // Create a thread to relay stdin to the pipe. - auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); wsl::shared::SocketChannel controlChannel{ wil::unique_socket{(SOCKET)process->GetStdHandle(2).release()}, "TerminalControl", exitEvent.get()}; diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt index 3fd90b5d5..270b8e91c 100644 --- a/src/windows/wslaservice/exe/CMakeLists.txt +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -15,6 +15,7 @@ set(SOURCES set(HEADERS ContainerEventTracker.h + docker_schema.h DockerHTTPClient.h ServiceProcessLauncher.h WSLAContainer.h diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp index dba77294a..0bbffe220 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.cpp +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -4,6 +4,7 @@ using boost::beast::http::verb; using wsl::windows::service::wsla::DockerHTTPClient; +using namespace wsl::windows::service::wsla::docker_schema; DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE exitingEvent, GUID VmId, ULONG ConnectTimeoutMs) : m_exitingEvent(exitingEvent), m_channel(std::move(Channel)), m_vmId(VmId), m_connectTimeoutMs(ConnectTimeoutMs) @@ -12,13 +13,42 @@ DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE uint32_t DockerHTTPClient::PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback) { - auto [code, _] = SendRequest(verb::post, std::format("http://localhost/images/create?fromImage=library/{}&tag={}", Name, Tag), {}, [Callback](const gsl::span& span) { - Callback(std::string{span.data(), span.size()}); - }); + auto [code, _] = SendRequest( + verb::post, + std::format("http://localhost/images/create?fromImage=library/{}&tag={}", Name, Tag), + {}, + [Callback](const gsl::span& span) { Callback(std::string{span.data(), span.size()}); }); return code; } +DockerHTTPClient::RequestResult DockerHTTPClient::CreateContainer(const docker_schema::CreateContainer& Request) +{ + // TODO: Url escaping. + return SendRequest(verb::post, "http://localhost/containers/create", Request); +} + +DockerHTTPClient::RequestResult DockerHTTPClient::StartContainer(const std::string& Id) +{ + RequestResult result; + std::tie(result.StatusCode, result.ResponseString) = Transaction(verb::post, std::format("http://localhost/containers/{}/start", Id)); + + return result; +} + +wil::unique_socket DockerHTTPClient::AttachContainer(const std::string& Id) +{ + std::map headers{ + {boost::beast::http::field::upgrade, "tcp"}, {boost::beast::http::field::connection, "upgrade"}}; + + auto [status, socket] = SendRequest( + verb::post, std::format("http://localhost/containers/{}/attach?stream=1&stdin=1&stdout=1&stderr=1&logs=true", Id), {}, {}, headers); + + THROW_HR_IF_MSG(E_FAIL, status != 101, "Failed to attach to container %hs: %i", Id.c_str(), status); + + return std::move(socket); +} + wil::unique_socket DockerHTTPClient::ConnectSocket() { auto lock = m_lock.lock_exclusive(); @@ -45,7 +75,22 @@ wil::unique_socket DockerHTTPClient::ConnectSocket() return newChannel.Release(); } -std::pair DockerHTTPClient::SendRequest(boost::beast::http::verb Method, const std::string& Url, const std::string& Body, const OnResponseBytes& OnResponse) +std::pair DockerHTTPClient::Transaction(verb Method, const std::string& Url, const std::string& Body) +{ + std::string responseBody; + auto OnResponse = [&responseBody](const gsl::span& span) { responseBody.append(span.data(), span.size()); }; + + auto [status, _] = SendRequest(Method, Url, Body, OnResponse); + + return {status, std::move(responseBody)}; +} + +std::pair DockerHTTPClient::SendRequest( + verb Method, + const std::string& Url, + const std::string& Body, + const OnResponseBytes& OnResponse, + const std::map& Headers) { namespace http = boost::beast::http; @@ -60,6 +105,17 @@ std::pair DockerHTTPClient::SendRequest(boost::bea req.set(http::field::host, "localhost"); req.set(http::field::connection, "close"); req.set(http::field::accept, "application/json"); + if (!Body.empty()) + { + req.set(http::field::content_type, "application/json"); + req.body() = Body; + } + + for (const auto [field, value] : Headers) + { + req.set(field, value); + } + req.prepare_payload(); http::write(stream, req); @@ -113,9 +169,6 @@ std::pair DockerHTTPClient::SendRequest(boost::bea parser.get().body().size = buffer.size(); http::read(stream, adapter, parser); - - WSL_LOG("Sizes", TraceLoggingValue(parser.get().body().size)); - auto bytesRead = buffer.size() - parser.get().body().size; OnResponse(gsl::span{buffer.data(), bytesRead}); diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.h b/src/windows/wslaservice/exe/DockerHTTPClient.h index 269ae9fd6..0e0d20773 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.h +++ b/src/windows/wslaservice/exe/DockerHTTPClient.h @@ -4,6 +4,7 @@ #include #include #include +#include "docker_schema.h" namespace wsl::windows::service::wsla { @@ -15,15 +16,72 @@ class DockerHTTPClient using OnResponseBytes = std::function)>; using OnImageProgress = std::function; + template + struct RequestResult + { + uint32_t StatusCode; + std::optional ResponseObject; + std::string ResponseString; + std::string RequestString; + + std::string Format() + { + return std::format("{} -> {}({})", RequestString, StatusCode, ResponseString); + } + }; + + template <> + struct RequestResult + { + uint32_t StatusCode; + std::string ResponseString; + std::string RequestString; + + std::string Format() + { + return std::format("{} -> {}({})", RequestString, StatusCode, ResponseString); + } + }; + DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE ExitingEvent, GUID VmId, ULONG ConnectTimeoutMs); + RequestResult CreateContainer(const docker_schema::CreateContainer& Request); + RequestResult StartContainer(const std::string& Id); + + wil::unique_socket AttachContainer(const std::string& Id); uint32_t PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback); - std::pair SendRequest(boost::beast::http::verb Method, const std::string& Url, const std::string& Body = "", const OnResponseBytes& OnResponse = {}); + std::pair SendRequest( + boost::beast::http::verb Method, + const std::string& Url, + const std::string& Body = "", + const OnResponseBytes& OnResponse = {}, + const std::map& Headers = {}); + + std::pair Transaction( + boost::beast::http::verb Method, const std::string& Url, const std::string& Body = ""); private: wil::unique_socket ConnectSocket(); + template + auto SendRequest(boost::beast::http::verb Method, const std::string& Url, const TRequest& Request) + { + RequestResult result; + result.RequestString = wsl::shared::ToJson(Request); + std::tie(result.StatusCode, result.ResponseString) = Transaction(Method, Url, result.RequestString); + + if constexpr (!std::is_same_v) + { + if (result.StatusCode >= 200 && result.StatusCode < 300) + { + result.ResponseObject = wsl::shared::FromJson(result.ResponseString.c_str()); + } + } + + return result; + } + ULONG m_connectTimeoutMs{}; GUID m_vmId; shared::SocketChannel m_channel; diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index e42808302..23ef3253c 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -138,21 +138,23 @@ WSLAContainerImpl::WSLAContainerImpl( WSLAVirtualMachine* parentVM, const WSLA_CONTAINER_OPTIONS& Options, std::string&& Id, - ContainerEventTracker& tracker, std::vector&& volumes, std::vector&& ports, - std::function&& onDeleted) : + std::function&& onDeleted, + DockerHTTPClient& DockerClient) : m_parentVM(parentVM), m_name(Options.Name), m_image(Options.Image), m_id(std::move(Id)), m_mountedVolumes(std::move(volumes)), m_mappedPorts(std::move(ports)), - m_comWrapper(wil::MakeOrThrow(this, std::move(onDeleted))) + m_comWrapper(wil::MakeOrThrow(this, std::move(onDeleted))), + m_dockerClient(DockerClient) { m_state = WslaContainerStateCreated; - m_trackingReference = tracker.RegisterContainerStateUpdates(m_id, std::bind(&WSLAContainerImpl::OnEvent, this, std::placeholders::_1)); + // Attach to the tty now. This is required not to 'drop' tty content before GetTtyHandle is called(). + m_TtyHandle = m_dockerClient.AttachContainer(m_id); } WSLAContainerImpl::~WSLAContainerImpl() @@ -179,8 +181,6 @@ WSLAContainerImpl::~WSLAContainerImpl() CATCH_LOG(); } - m_trackingReference.Reset(); - // Release port mappings. std::set allocatedGuestPorts; for (const auto& e : m_mappedPorts) @@ -221,36 +221,17 @@ void WSLAContainerImpl::Start(const WSLA_CONTAINER_OPTIONS& Options) m_name.c_str(), m_state); - ServiceProcessLauncher launcher(nerdctlPath, {nerdctlPath, "start", "-a", m_id}, defaultNerdctlEnv, common::ProcessFlags::None); - for (auto i = 0; i < Options.InitProcessOptions.FdsCount; i++) - { - launcher.AddFd(Options.InitProcessOptions.Fds[i]); - } - - m_containerProcess = launcher.Launch(*m_parentVM); + auto result = m_dockerClient.StartContainer(m_id); + THROW_HR_IF_MSG(E_FAIL, result.StatusCode != 204, "Failed to start container: %hs, %hs", m_id.c_str(), result.Format().c_str()); - auto cleanup = wil::scope_exit([&]() { m_containerProcess.reset(); }); - - // Wait for either the container to get into a 'started' state, or the nerdctl process to exit. - common::relay::MultiHandleWait wait; - wait.AddHandle(std::make_unique(m_containerProcess->GetExitEvent(), [&]() { wait.Cancel(); })); - wait.AddHandle(std::make_unique(m_startedEvent.get(), [&]() { wait.Cancel(); })); - wait.Run({}); - - if (!m_startedEvent.is_signaled()) - { - auto status = GetNerdctlStatus(); + m_state = WslaContainerStateRunning; +} - THROW_HR_IF_MSG( - E_FAIL, - status != "exited", - "Failed to start container %hs, nerdctl status: %hs", - m_name.c_str(), - status.value_or("").c_str()); - } +void WSLAContainerImpl::GetTtyHandle(ULONG* Handle) +{ + std::lock_guard lock(m_lock); - cleanup.release(); - m_state = WslaContainerStateRunning; + *Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess((HANDLE)m_TtyHandle.get())); } void WSLAContainerImpl::OnEvent(ContainerEvent event) @@ -471,49 +452,26 @@ void wsl::windows::service::wsla::WSLAContainerImpl::UnmountVolumes(const std::v std::unique_ptr WSLAContainerImpl::Create( const WSLA_CONTAINER_OPTIONS& containerOptions, WSLAVirtualMachine& parentVM, - ContainerEventTracker& eventTracker, - std::function&& OnDeleted) + std::function&& OnDeleted, + DockerHTTPClient& DockerClient) { + // TODO: Think about when 'StdinOnce' should be set. auto [hasStdin, hasTty] = ParseFdStatus(containerOptions.InitProcessOptions); - // Don't support stdin for now since it will hang. - // TODO: Remove once stdin is fixed in nerdctl. - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), hasStdin && !hasTty); - - std::vector inputOptions; - if (hasStdin) - { - inputOptions.push_back("-i"); - } - - if (hasTty) - { - inputOptions.push_back("-t"); - } + auto result = DockerClient.CreateContainer( + {.Image = containerOptions.Image, .Tty = hasTty, .OpenStdin = true, .StdinOnce = true, .AttachStdin = false, .AttachStdout = false, .AttachStderr = false}); - AddEnvironmentVariables(inputOptions, containerOptions.InitProcessOptions); + THROW_HR_IF_MSG(E_FAIL, !result.ResponseObject.has_value(), "Failed to create container: %hs", result.Format().c_str()); // TODO: Rethink command line generation logic. - auto [mappedPorts, errorCleanup] = ProcessPortMappings(containerOptions, parentVM, inputOptions); + std::vector dummy; + auto [mappedPorts, errorCleanup] = ProcessPortMappings(containerOptions, parentVM, dummy); auto volumes = MountVolumes(containerOptions, parentVM); - auto args = PrepareNerdctlCreateCommand(containerOptions, std::move(inputOptions), volumes); - - ServiceProcessLauncher launcher(nerdctlPath, args, defaultNerdctlEnv); - auto result = launcher.Launch(parentVM).WaitAndCaptureOutput(); - - // TODO: Have better error codes. - THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "Failed to create container: %hs", launcher.FormatResult(result).c_str()); - - auto id = result.Output[1]; - while (!id.empty() && (id.back() == '\n')) - { - id.pop_back(); - } // N.B. mappedPorts is explicitly copied because it's referenced in errorCleanup, so it can't be moved. auto container = std::make_unique( - &parentVM, containerOptions, std::move(id), eventTracker, std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted)); + &parentVM, containerOptions, std::move(result.ResponseObject->Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), DockerClient); errorCleanup.release(); @@ -653,6 +611,11 @@ HRESULT WSLAContainer::Stop(int Signal, ULONG TimeoutMs) return CallImpl(&WSLAContainerImpl::Stop, Signal, TimeoutMs); } +HRESULT WSLAContainer::GetTtyHandle(ULONG* Handle) +{ + return CallImpl(&WSLAContainerImpl::GetTtyHandle, Handle); +} + HRESULT WSLAContainer::Delete() try { diff --git a/src/windows/wslaservice/exe/WSLAContainer.h b/src/windows/wslaservice/exe/WSLAContainer.h index a4106c244..fca511e82 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.h +++ b/src/windows/wslaservice/exe/WSLAContainer.h @@ -18,6 +18,7 @@ Module Name: #include "wslaservice.h" #include "WSLAVirtualMachine.h" #include "ContainerEventTracker.h" +#include "DockerHTTPClient.h" namespace wsl::windows::service::wsla { @@ -50,10 +51,10 @@ class WSLAContainerImpl WSLAVirtualMachine* parentVM, const WSLA_CONTAINER_OPTIONS& Options, std::string&& Id, - ContainerEventTracker& tracker, std::vector&& volumes, std::vector&& ports, - std::function&& OnDeleted); + std::function&& OnDeleted, + DockerHTTPClient& DockerClient); ~WSLAContainerImpl(); void Start(const WSLA_CONTAINER_OPTIONS& Options); @@ -63,6 +64,7 @@ class WSLAContainerImpl void GetState(_Out_ WSLA_CONTAINER_STATE* State); void GetInitProcess(_Out_ IWSLAProcess** process); void Exec(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno); + void GetTtyHandle(_Out_ ULONG* Handle); IWSLAContainer& ComWrapper(); @@ -72,8 +74,8 @@ class WSLAContainerImpl static std::unique_ptr Create( const WSLA_CONTAINER_OPTIONS& Options, WSLAVirtualMachine& parentVM, - ContainerEventTracker& tracker, - std::function&& OnDeleted); + std::function&& OnDeleted, + DockerHTTPClient& DockerClient); private: void OnEvent(ContainerEvent event); @@ -87,12 +89,13 @@ class WSLAContainerImpl std::string m_name; std::string m_image; std::string m_id; + DockerHTTPClient& m_dockerClient; WSLA_CONTAINER_STATE m_state = WslaContainerStateInvalid; WSLAVirtualMachine* m_parentVM = nullptr; - ContainerEventTracker::ContainerTrackingReference m_trackingReference; std::vector m_mappedPorts; std::vector m_mountedVolumes; Microsoft::WRL::ComPtr m_comWrapper; + wil::unique_socket m_TtyHandle; static std::vector PrepareNerdctlCreateCommand( const WSLA_CONTAINER_OPTIONS& options, std::vector&& inputOptions, std::vector& volumes); @@ -115,6 +118,7 @@ class DECLSPEC_UUID("B1F1C4E3-C225-4CAE-AD8A-34C004DE1AE4") WSLAContainer IFACEMETHOD(GetState)(_Out_ WSLA_CONTAINER_STATE* State) override; IFACEMETHOD(GetInitProcess)(_Out_ IWSLAProcess** process) override; IFACEMETHOD(Exec)(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno) override; + IFACEMETHOD(GetTtyHandle)(_Out_ ULONG* Handle) override; void Disconnect() noexcept; diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 94abcb1cd..fe7ccac30 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -308,8 +308,7 @@ try size_t separator = image.find(':'); THROW_HR_IF_MSG(E_INVALIDARG, separator == std::string::npos || separator >= image.size() - 1, "Invalid image: %hs", ImageUri); - auto callback = [&](const std::string &content) - { + auto callback = [&](const std::string& content) { WSL_LOG("ImagePullProgress", TraceLoggingValue(ImageUri, "Image"), TraceLoggingValue(content.c_str(), "Content")); }; @@ -415,7 +414,11 @@ try auto [container, inserted] = m_containers.emplace( containerOptions->Name, WSLAContainerImpl::Create( - *containerOptions, *m_virtualMachine.Get(), *m_eventTracker, std::bind(&WSLASession::OnContainerDeleted, this, std::placeholders::_1))); + *containerOptions, + *m_virtualMachine.Get(), + std::bind(&WSLASession::OnContainerDeleted, this, std::placeholders::_1), + m_dockerClient.value())); + WI_ASSERT(inserted); container->second->Start(*containerOptions); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 8f744da71..a458325ad 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -402,8 +402,7 @@ void WSLAVirtualMachine::ConfigureMounts() "net_prio", "hugetlb", "pids", - "rdma" - }; + "rdma"}; for (const auto* e : cgroups) { @@ -1546,13 +1545,13 @@ void WSLAVirtualMachine::ReleasePorts(const std::set& Ports) } } -wil::unique_socket WSLAVirtualMachine::ConnectUnixSocket(const char *Path) +wil::unique_socket WSLAVirtualMachine::ConnectUnixSocket(const char* Path) { auto [_, __, channel] = Fork(WSLA_FORK::Thread); shared::MessageWriter message; message.WriteString(message->PathOffset, Path); - + auto result = channel.Transaction(message.Span()); THROW_HR_IF_MSG(E_FAIL, result.Result < 0, "Failed to connect to unix socket: '%hs', %i", Path, result.Result); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 255987aea..6de6ab569 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -109,7 +109,6 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine return m_vmId; } - private: static void Mount(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); void MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint); diff --git a/src/windows/wslaservice/exe/docker_schema.h b/src/windows/wslaservice/exe/docker_schema.h new file mode 100644 index 000000000..64dd980fc --- /dev/null +++ b/src/windows/wslaservice/exe/docker_schema.h @@ -0,0 +1,30 @@ +#pragma once + +#include "JsonUtils.h" + +namespace wsl::windows::service::wsla::docker_schema { + +struct CreatedContainer +{ + std::string Id; + std::vector Warnings; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(CreatedContainer, Id, Warnings); +}; + +struct CreateContainer +{ + using TResponse = CreatedContainer; + std::string Image; + bool Tty{}; + bool OpenStdin{}; + bool StdinOnce{}; + bool AttachStdin{}; + bool AttachStdout{}; + bool AttachStderr{}; + std::vector Cmd; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(CreateContainer, Image, Cmd, Tty, OpenStdin, StdinOnce); +}; + +} // namespace wsl::windows::service::wsla::docker_schema \ No newline at end of file diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index e89a8fc1b..387fd6624 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -280,6 +280,7 @@ interface IWSLAContainer : IUnknown HRESULT GetState([out] enum WSLA_CONTAINER_STATE* State); HRESULT GetInitProcess([out] IWSLAProcess** Process); HRESULT Exec([in] const struct WSLA_PROCESS_OPTIONS* Options, [out] IWSLAProcess** Process, [out] int* Errno); + HRESULT GetTtyHandle([out] ULONG * Handle); // Anonymous host port allocation (P1). //HRESULT AllocateHostPort([in] LPCSTR Name, [in] USHORT ContainerPort, [out] USHORT* AllocatedHostPort); From 5b9accd13bf2e73ad2c9ba5b0e15f6e2df3a132b Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 26 Dec 2025 14:37:50 -0800 Subject: [PATCH 06/26] Save state --- CMakeLists.txt | 15 +---- msipackage/package.wix.in | 6 ++ src/shared/inc/lxinitshared.h | 54 +----------------- src/windows/common/WslClient.cpp | 9 +-- src/windows/common/wslutil.cpp | 3 +- src/windows/wslaservice/exe/CMakeLists.txt | 5 +- .../wslaservice/exe/DockerHTTPClient.cpp | 9 ++- .../wslaservice/exe/DockerHTTPClient.h | 55 +++++++++++++++---- src/windows/wslaservice/exe/WSLAContainer.cpp | 31 ++++++----- src/windows/wslaservice/exe/WSLAContainer.h | 7 ++- src/windows/wslaservice/exe/WSLAProcess.cpp | 5 ++ src/windows/wslaservice/exe/WSLAProcess.h | 1 + src/windows/wslaservice/exe/WSLASession.cpp | 2 - src/windows/wslaservice/exe/docker_schema.h | 17 ++++++ src/windows/wslaservice/inc/wslaservice.idl | 2 + 15 files changed, 116 insertions(+), 105 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d6095651..f557fdc47 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,24 +40,14 @@ FetchContent_Declare(nlohmannjson FetchContent_MakeAvailable(nlohmannjson) FetchContent_GetProperties(nlohmannjson SOURCE_DIR NLOHMAN_JSON_SOURCE_DIR) -FetchContent_Declare(httplib SYSTEM - GIT_REPOSITORY https://github.com/yhirose/cpp-httplib - GIT_TAG v0.28.0 - GIT_SHALLOW TRUE) - -FetchContent_MakeAvailable(httplib) -FetchContent_GetProperties(httplib SOURCE_DIR HTTP_LIB_SOURCE_DIR) - - set(BOOST_VERSION "1.90.0") set(BOOST_TARBALL "boost_${BOOST_VERSION}") -string(REPLACE "." "_" BOOST_TARBALL "${BOOST_TARBALL}") # 1.84.0 -> 1_84_0 +string(REPLACE "." "_" BOOST_TARBALL "${BOOST_TARBALL}") FetchContent_Declare( boost_headers URL https://archives.boost.io/release/${BOOST_VERSION}/source/${BOOST_TARBALL}.tar.gz - # You can add URL_HASH to pin integrity: - # URL_HASH SHA256= + # URL_HASH SHA256=TODO ) # Download & unpack to boost_headers_SOURCE_DIR (no add_subdirectory!) @@ -344,7 +334,6 @@ set(LINUX_COMMON_FLAGS --gcc-toolchain=${LINUXSDK_PATH} -I "${CMAKE_CURRENT_LIST_DIR}/src/shared/configfile" -I "${CMAKE_CURRENT_LIST_DIR}/src/shared/inc" -I "${NLOHMAN_JSON_SOURCE_DIR}/include" - -I "${HTTP_LIB_SOURCE_DIR}" -I "${CMAKE_BINARY_DIR}/generated" --no-standard-libraries -Werror diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 0ef367e0d..370d482a7 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -262,6 +262,12 @@ + + + + + + diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index 0cc144377..f7b7639fb 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -398,8 +398,6 @@ typedef enum _LX_MESSAGE_TYPE LxMessageWSLATerminalChanged, LxMessageWSLAWatchProcesses, LxMessageWSLAProcessExited, - LxMessageWSLAHTTPResponse, - LxMessageWSLAHTTPRequest, LxMessageWSLAUnixConnect, } LX_MESSAGE_TYPE, *PLX_MESSAGE_TYPE; @@ -511,8 +509,7 @@ inline auto ToString(LX_MESSAGE_TYPE messageType) X(LxMessageWSLADetach) X(LxMessageWSLATerminalChanged) X(LxMessageWSLAWatchProcesses) - X(LxMessageWSLAHTTPResponse) - X(LxMessageWSLAHTTPRequest) + X(LxMessageWSLAProcessExited) X(LxMessageWSLAUnixConnect) default: @@ -1886,55 +1883,6 @@ struct WSLA_PROCESS_EXITED PRETTY_PRINT(FIELD(Header), FIELD(Pid), FIELD(Code), FIELD(Signaled)); }; -struct WSLA_HTTP_RESPONSE -{ - static inline auto Type = LxMessageWSLAHTTPResponse; - - int Errno; - unsigned int StatusCode; - unsigned int ContentSize; -}; - -enum class HTTPMethod -{ - GET, - POST -}; - -inline auto ToString(HTTPMethod type) -{ - if (type == HTTPMethod::GET) - { - return "GET"; - } - else if (type == HTTPMethod::POST) - { - return "POST"; - } - else - { - return "Unknown"; - } -} - -inline void PrettyPrint(std::stringstream& Out, HTTPMethod Value) -{ - Out << ToString(Value); -} - -struct WSLA_HTTP_REQUEST -{ - static inline auto Type = LxMessageWSLAHTTPRequest; - using TResponse = WSLA_HTTP_RESPONSE; - - HTTPMethod Method; - unsigned int UrlOffset; - unsigned int BodyOffset; - unsigned int ContentTypeOffset; - - PRETTY_PRINT(FIELD(Method), STRING_FIELD(UrlOffset), STRING_FIELD(BodyOffset), STRING_FIELD(ContentTypeOffset)); -}; - struct WSLA_UNIX_CONNECT { static inline auto Type = LxMessageWSLAUnixConnect; diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 529b001a2..b71b42da9 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1661,10 +1661,12 @@ int WslaShell(_In_ std::wstring_view commandLine) container.emplace(); THROW_IF_FAILED(session->CreateContainer(&containerOptions, &container.value())); + THROW_IF_FAILED((*container)->Start()); wil::com_ptr initProcess; - // THROW_IF_FAILED((*container)->GetInitProcess(&initProcess)); - // process.emplace(std::move(initProcess), std::move(fds)); + THROW_IF_FAILED((*container)->GetInitProcess(&initProcess)); + process.emplace(std::move(initProcess), std::move(fds)); + } // Save original console modes so they can be restored on exit. @@ -1696,8 +1698,7 @@ int WslaShell(_In_ std::wstring_view commandLine) if (!containerImage.empty()) { - wil::unique_handle ttyHandle; - THROW_IF_FAILED(container->get()->GetTtyHandle(reinterpret_cast(&ttyHandle))); + auto ttyHandle = process->GetStdHandle(0); std::thread inputThread( [&]() { wsl::windows::common::relay::StandardInputRelay(Stdin, ttyHandle.get(), []() {}, exitEvent.get()); }); diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index 8c4ad8ac3..81258d431 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -144,7 +144,8 @@ static const std::map g_commonErrors{ X_WIN32(WSAECONNREFUSED), X_WIN32(ERROR_BAD_PATHNAME), X(WININET_E_TIMEOUT), - X_WIN32(ERROR_INVALID_SID)}; + X_WIN32(ERROR_INVALID_SID), + X_WIN32(ERROR_INVALID_STATE)}; #undef X diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt index 270b8e91c..2833d58da 100644 --- a/src/windows/wslaservice/exe/CMakeLists.txt +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -1,5 +1,6 @@ set(SOURCES application.manifest + WSLAContainerProcess.cpp ContainerEventTracker.cpp DockerHTTPClient.cpp main.rc @@ -15,6 +16,7 @@ set(SOURCES set(HEADERS ContainerEventTracker.h + WSLAContainerProcess.h docker_schema.h DockerHTTPClient.h ServiceProcessLauncher.h @@ -41,8 +43,7 @@ target_link_libraries(wslaservice legacy_stdio_definitions VirtDisk.lib Winhttp.lib - Synchronization.lib - httplib) + Synchronization.lib) target_precompile_headers(wslaservice REUSE_FROM common) set_target_properties(wslaservice PROPERTIES FOLDER windows) diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp index 0bbffe220..d177006f2 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.cpp +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -22,10 +22,15 @@ uint32_t DockerHTTPClient::PullImage(const char* Name, const char* Tag, const On return code; } -DockerHTTPClient::RequestResult DockerHTTPClient::CreateContainer(const docker_schema::CreateContainer& Request) +CreatedContainer DockerHTTPClient::CreateContainer(const docker_schema::CreateContainer& Request) { // TODO: Url escaping. - return SendRequest(verb::post, "http://localhost/containers/create", Request); + return Transaction(verb::post, "http://localhost/containers/create", Request); +} + +void DockerHTTPClient::ResizeContainerTty(const std::string& Id, ULONG Rows, ULONG Columns) +{ + Transaction(verb::post, std::format("http://localhost/containers/{}/resize?w={}&h={}", Id, Columns, Rows)); } DockerHTTPClient::RequestResult DockerHTTPClient::StartContainer(const std::string& Id) diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.h b/src/windows/wslaservice/exe/DockerHTTPClient.h index 0e0d20773..fd9f6b034 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.h +++ b/src/windows/wslaservice/exe/DockerHTTPClient.h @@ -8,6 +8,31 @@ namespace wsl::windows::service::wsla { +class DockerHTTPExceptions : std::runtime_error +{ +public: + DockerHTTPExceptions(uint16_t StatusCode, const std::string& Url, const std::string& RequestContent, const std::string& ResponseContent) : + std::runtime_error(std::format("HTTP request failed: {} -> {} (Request: {}, Response: {})", Url, StatusCode, RequestContent, ResponseContent)), + m_statusCode(StatusCode), + m_url(Url), + m_request(RequestContent), + m_response(ResponseContent) + { + } + + template + T DockerMessage() + { + return wsl::shared::FromJson(m_response.c_str()); + } + +private: + uint16_t m_statusCode{}; + std::string m_url; + std::string m_request; + std::string m_response; +}; + class DockerHTTPClient { NON_COPYABLE(DockerHTTPClient); @@ -45,11 +70,13 @@ class DockerHTTPClient DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE ExitingEvent, GUID VmId, ULONG ConnectTimeoutMs); - RequestResult CreateContainer(const docker_schema::CreateContainer& Request); + docker_schema::CreatedContainer CreateContainer(const docker_schema::CreateContainer& Request); RequestResult StartContainer(const std::string& Id); wil::unique_socket AttachContainer(const std::string& Id); + void ResizeContainerTty(const std::string& Id, ULONG Rows, ULONG Columns); + uint32_t PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback); std::pair SendRequest( boost::beast::http::verb Method, @@ -64,22 +91,26 @@ class DockerHTTPClient private: wil::unique_socket ConnectSocket(); - template - auto SendRequest(boost::beast::http::verb Method, const std::string& Url, const TRequest& Request) + template + auto Transaction(boost::beast::http::verb Method, const std::string& Url, const TRequest& RequestObject = {}) { - RequestResult result; - result.RequestString = wsl::shared::ToJson(Request); - std::tie(result.StatusCode, result.ResponseString) = Transaction(Method, Url, result.RequestString); + std::string requestString; + if constexpr (!std::is_same_v) + { + requestString = wsl::shared::ToJson(RequestObject); + } + + auto [statusCode, responseString] = Transaction(Method, Url, requestString); - if constexpr (!std::is_same_v) + if (statusCode < 200 || statusCode >= 300) { - if (result.StatusCode >= 200 && result.StatusCode < 300) - { - result.ResponseObject = wsl::shared::FromJson(result.ResponseString.c_str()); - } + throw DockerHTTPExceptions(statusCode, Url, requestString, responseString); } - return result; + if constexpr (!std::is_same_v) + { + return wsl::shared::FromJson(responseString.c_str()); + } } ULONG m_connectTimeoutMs{}; diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 23ef3253c..4596a3b9c 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -150,11 +150,8 @@ WSLAContainerImpl::WSLAContainerImpl( m_mappedPorts(std::move(ports)), m_comWrapper(wil::MakeOrThrow(this, std::move(onDeleted))), m_dockerClient(DockerClient) -{ +{ m_state = WslaContainerStateCreated; - - // Attach to the tty now. This is required not to 'drop' tty content before GetTtyHandle is called(). - m_TtyHandle = m_dockerClient.AttachContainer(m_id); } WSLAContainerImpl::~WSLAContainerImpl() @@ -210,7 +207,7 @@ IWSLAContainer& WSLAContainerImpl::ComWrapper() return *m_comWrapper.Get(); } -void WSLAContainerImpl::Start(const WSLA_CONTAINER_OPTIONS& Options) +void WSLAContainerImpl::Start() { std::lock_guard lock(m_lock); @@ -221,17 +218,22 @@ void WSLAContainerImpl::Start(const WSLA_CONTAINER_OPTIONS& Options) m_name.c_str(), m_state); + // Attach to the container's init process so no IO is lost. + m_initProcess.emplace(std::string{m_id}, wil::unique_handle{(HANDLE)m_dockerClient.AttachContainer(m_id).release()}, true, m_dockerClient); + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [this]() mutable { m_initProcess.reset(); }); + auto result = m_dockerClient.StartContainer(m_id); THROW_HR_IF_MSG(E_FAIL, result.StatusCode != 204, "Failed to start container: %hs, %hs", m_id.c_str(), result.Format().c_str()); m_state = WslaContainerStateRunning; + cleanup.release(); } void WSLAContainerImpl::GetTtyHandle(ULONG* Handle) { std::lock_guard lock(m_lock); - *Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess((HANDLE)m_TtyHandle.get())); + //*Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess((HANDLE)m_TtyHandle.get())); } void WSLAContainerImpl::OnEvent(ContainerEvent event) @@ -311,10 +313,10 @@ WSLA_CONTAINER_STATE WSLAContainerImpl::State() noexcept std::lock_guard lock(m_lock); // If the container is running, refresh the init process state before returning. - if (m_state == WslaContainerStateRunning && m_containerProcess->State() != WSLAProcessStateRunning) + // if (m_state == WslaContainerStateRunning && m_containerProcess->State() != WSLAProcessStateRunning) { m_state = WslaContainerStateExited; - m_containerProcess.reset(); + // m_containerProcess.reset(); } return m_state; @@ -329,8 +331,8 @@ void WSLAContainerImpl::GetInitProcess(IWSLAProcess** Process) { std::lock_guard lock(m_lock); - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_containerProcess.has_value()); - THROW_IF_FAILED(m_containerProcess->Get().QueryInterface(__uuidof(IWSLAProcess), (void**)Process)); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_initProcess.has_value()); + THROW_IF_FAILED(m_initProcess->QueryInterface(__uuidof(IWSLAProcess), (void**)Process)); } void WSLAContainerImpl::Exec(const WSLA_PROCESS_OPTIONS* Options, IWSLAProcess** Process, int* Errno) @@ -461,8 +463,6 @@ std::unique_ptr WSLAContainerImpl::Create( auto result = DockerClient.CreateContainer( {.Image = containerOptions.Image, .Tty = hasTty, .OpenStdin = true, .StdinOnce = true, .AttachStdin = false, .AttachStdout = false, .AttachStderr = false}); - THROW_HR_IF_MSG(E_FAIL, !result.ResponseObject.has_value(), "Failed to create container: %hs", result.Format().c_str()); - // TODO: Rethink command line generation logic. std::vector dummy; auto [mappedPorts, errorCleanup] = ProcessPortMappings(containerOptions, parentVM, dummy); @@ -471,7 +471,7 @@ std::unique_ptr WSLAContainerImpl::Create( // N.B. mappedPorts is explicitly copied because it's referenced in errorCleanup, so it can't be moved. auto container = std::make_unique( - &parentVM, containerOptions, std::move(result.ResponseObject->Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), DockerClient); + &parentVM, containerOptions, std::move(result.Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), DockerClient); errorCleanup.release(); @@ -616,6 +616,11 @@ HRESULT WSLAContainer::GetTtyHandle(ULONG* Handle) return CallImpl(&WSLAContainerImpl::GetTtyHandle, Handle); } +HRESULT WSLAContainer::Start() +{ + return CallImpl(&WSLAContainerImpl::Start); +} + HRESULT WSLAContainer::Delete() try { diff --git a/src/windows/wslaservice/exe/WSLAContainer.h b/src/windows/wslaservice/exe/WSLAContainer.h index fca511e82..6141d8f72 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.h +++ b/src/windows/wslaservice/exe/WSLAContainer.h @@ -19,6 +19,7 @@ Module Name: #include "WSLAVirtualMachine.h" #include "ContainerEventTracker.h" #include "DockerHTTPClient.h" +#include "WSLAContainerProcess.h" namespace wsl::windows::service::wsla { @@ -57,7 +58,7 @@ class WSLAContainerImpl DockerHTTPClient& DockerClient); ~WSLAContainerImpl(); - void Start(const WSLA_CONTAINER_OPTIONS& Options); + void Start(); void Stop(_In_ int Signal, _In_ ULONG TimeoutMs); void Delete(); @@ -85,7 +86,6 @@ class WSLAContainerImpl std::recursive_mutex m_lock; wil::unique_event m_startedEvent{wil::EventOptions::ManualReset}; - std::optional m_containerProcess; std::string m_name; std::string m_image; std::string m_id; @@ -95,7 +95,7 @@ class WSLAContainerImpl std::vector m_mappedPorts; std::vector m_mountedVolumes; Microsoft::WRL::ComPtr m_comWrapper; - wil::unique_socket m_TtyHandle; + std::optional m_initProcess; static std::vector PrepareNerdctlCreateCommand( const WSLA_CONTAINER_OPTIONS& options, std::vector&& inputOptions, std::vector& volumes); @@ -119,6 +119,7 @@ class DECLSPEC_UUID("B1F1C4E3-C225-4CAE-AD8A-34C004DE1AE4") WSLAContainer IFACEMETHOD(GetInitProcess)(_Out_ IWSLAProcess** process) override; IFACEMETHOD(Exec)(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno) override; IFACEMETHOD(GetTtyHandle)(_Out_ ULONG* Handle) override; + IFACEMETHOD(Start)() override; void Disconnect() noexcept; diff --git a/src/windows/wslaservice/exe/WSLAProcess.cpp b/src/windows/wslaservice/exe/WSLAProcess.cpp index 95269a36f..48f6ad0a8 100644 --- a/src/windows/wslaservice/exe/WSLAProcess.cpp +++ b/src/windows/wslaservice/exe/WSLAProcess.cpp @@ -131,6 +131,11 @@ try } CATCH_RETURN(); +HRESULT WSLAProcess::ResizeTty(ULONG Rows, ULONG Columns) +{ + return E_NOTIMPL; +} + void WSLAProcess::OnTerminated(bool Signalled, int Code) { WI_ASSERT(m_virtualMachine != nullptr); diff --git a/src/windows/wslaservice/exe/WSLAProcess.h b/src/windows/wslaservice/exe/WSLAProcess.h index c8639ac07..cd06dc187 100644 --- a/src/windows/wslaservice/exe/WSLAProcess.h +++ b/src/windows/wslaservice/exe/WSLAProcess.h @@ -33,6 +33,7 @@ class DECLSPEC_UUID("AFBEA6D6-D8A4-4F81-8FED-F947EB74B33B") WSLAProcess IFACEMETHOD(GetStdHandle)(_In_ ULONG Index, _Out_ ULONG* Handle) override; IFACEMETHOD(GetPid)(_Out_ int* Pid) override; IFACEMETHOD(GetState)(_Out_ WSLA_PROCESS_STATE* State, _Out_ int* Code) override; + IFACEMETHOD(ResizeTty)(_In_ ULONG Rows, _In_ ULONG Columns) override; void OnTerminated(bool Signalled, int Code); void OnVmTerminated(); diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index fe7ccac30..4214bd93d 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -421,8 +421,6 @@ try WI_ASSERT(inserted); - container->second->Start(*containerOptions); - THROW_IF_FAILED(container->second->ComWrapper().QueryInterface(__uuidof(IWSLAContainer), (void**)Container)); return S_OK; diff --git a/src/windows/wslaservice/exe/docker_schema.h b/src/windows/wslaservice/exe/docker_schema.h index 64dd980fc..22358fd35 100644 --- a/src/windows/wslaservice/exe/docker_schema.h +++ b/src/windows/wslaservice/exe/docker_schema.h @@ -12,9 +12,26 @@ struct CreatedContainer NLOHMANN_DEFINE_TYPE_INTRUSIVE(CreatedContainer, Id, Warnings); }; +struct ErrorResponse +{ + std::string Message; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(ErrorResponse, Message); +}; + +struct EmtpyResponse +{ +}; + +struct EmtpyRequest +{ + using TResponse = EmtpyResponse; +}; + struct CreateContainer { using TResponse = CreatedContainer; + std::string Image; bool Tty{}; bool OpenStdin{}; diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 387fd6624..756698b4f 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -211,6 +211,7 @@ interface IWSLAProcess : IUnknown HRESULT GetStdHandle([in] ULONG Index, [out] ULONG* Handle); HRESULT GetPid([out] int* Pid); HRESULT GetState([out] enum WSLA_PROCESS_STATE* State, [out] int* Code); + HRESULT ResizeTty([in] ULONG Rows, [in] ULONG Columns); // Note: the SDK can offer a convenience Wait() method, but that doesn't need to be part of the service API. } @@ -276,6 +277,7 @@ struct WSLA_SESSION_SETTINGS { interface IWSLAContainer : IUnknown { HRESULT Stop([in] int Signal, [in] ULONG TimeoutMs); + HRESULT Start(); HRESULT Delete(); // TODO: Look into lifetime logic. HRESULT GetState([out] enum WSLA_CONTAINER_STATE* State); HRESULT GetInitProcess([out] IWSLAProcess** Process); From 74650e55d8fd9427cd71ec1f745f50212c65a9c7 Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 26 Dec 2025 14:55:04 -0800 Subject: [PATCH 07/26] New process class --- src/windows/common/WslClient.cpp | 36 +++++++++++++------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index b71b42da9..a14883f00 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1696,31 +1696,26 @@ int WslaShell(_In_ std::wstring_view commandLine) auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); + std::vector handleStorage; + HANDLE ttyInput = nullptr; + HANDLE ttyOutput = nullptr; if (!containerImage.empty()) { - auto ttyHandle = process->GetStdHandle(0); - - std::thread inputThread( - [&]() { wsl::windows::common::relay::StandardInputRelay(Stdin, ttyHandle.get(), []() {}, exitEvent.get()); }); - - auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { - exitEvent.SetEvent(); - inputThread.join(); - }); - - // Relay the contents of the pipe to stdout. - wsl::windows::common::relay::InterruptableRelay(ttyHandle.get(), Stdout); - return 0; + auto& it = handleStorage.emplace_back(process->GetStdHandle(0)); + ttyInput = it.get(); + ttyOutput = it.get(); + } + else + { + ttyInput = handleStorage.emplace_back(process->GetStdHandle(0)).get(); + ttyOutput = handleStorage.emplace_back(process->GetStdHandle(1)).get(); } { // Create a thread to relay stdin to the pipe. - wsl::shared::SocketChannel controlChannel{ - wil::unique_socket{(SOCKET)process->GetStdHandle(2).release()}, "TerminalControl", exitEvent.get()}; - std::thread inputThread([&]() { - auto updateTerminal = [&controlChannel, &Stdout]() { + auto updateTerminal = [&Stdout, &process]() { CONSOLE_SCREEN_BUFFER_INFOEX info{}; info.cbSize = sizeof(info); @@ -1729,11 +1724,10 @@ int WslaShell(_In_ std::wstring_view commandLine) WSLA_TERMINAL_CHANGED message{}; message.Columns = info.srWindow.Right - info.srWindow.Left + 1; message.Rows = info.srWindow.Bottom - info.srWindow.Top + 1; - - controlChannel.SendMessage(message); + LOG_IF_FAILED(process->Get().ResizeTty(message.Rows, message.Columns)); }; - wsl::windows::common::relay::StandardInputRelay(Stdin, process->GetStdHandle(0).get(), updateTerminal, exitEvent.get()); + wsl::windows::common::relay::StandardInputRelay(Stdin, ttyInput, updateTerminal, exitEvent.get()); }); auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { @@ -1742,7 +1736,7 @@ int WslaShell(_In_ std::wstring_view commandLine) }); // Relay the contents of the pipe to stdout. - wsl::windows::common::relay::InterruptableRelay(process->GetStdHandle(1).get(), Stdout); + wsl::windows::common::relay::InterruptableRelay(ttyOutput, Stdout); } process->GetExitEvent().wait(); From 364ad34b5b7e25e90603d65adeb77bb706dc337d Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 26 Dec 2025 16:43:53 -0800 Subject: [PATCH 08/26] Wire event tracker --- src/linux/init/WSLAInit.cpp | 2 - src/windows/common/relay.cpp | 76 ++++++++++++++++++ src/windows/common/relay.hpp | 20 ++++- .../wslaservice/exe/ContainerEventTracker.cpp | 77 ++++++++----------- .../wslaservice/exe/ContainerEventTracker.h | 6 +- .../wslaservice/exe/DockerHTTPClient.cpp | 39 ++++++++-- .../wslaservice/exe/DockerHTTPClient.h | 33 ++++---- src/windows/wslaservice/exe/WSLAContainer.cpp | 48 ++++-------- src/windows/wslaservice/exe/WSLAContainer.h | 4 +- .../wslaservice/exe/WSLAContainerProcess.cpp | 75 ++++++++++++++++++ .../wslaservice/exe/WSLAContainerProcess.h | 31 ++++++++ src/windows/wslaservice/exe/WSLASession.cpp | 4 +- src/windows/wslaservice/exe/docker_schema.h | 6 +- 13 files changed, 306 insertions(+), 115 deletions(-) create mode 100644 src/windows/wslaservice/exe/WSLAContainerProcess.cpp create mode 100644 src/windows/wslaservice/exe/WSLAContainerProcess.h diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index dcb3d8bfd..11e8e7546 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -277,8 +277,6 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_UNIX_CONN } } } - - LOG_ERROR("Relay exited"); } void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_TTY_RELAY& Message, const gsl::span&) diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index 270044f63..be79057ff 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -17,6 +17,7 @@ Module Name: #pragma hdrstop using wsl::windows::common::relay::EventHandle; +using wsl::windows::common::relay::HTTPChunkBasedReadHandle; using wsl::windows::common::relay::IOHandleStatus; using wsl::windows::common::relay::LineBasedReadHandle; using wsl::windows::common::relay::MultiHandleWait; @@ -1163,6 +1164,81 @@ void LineBasedReadHandle::OnRead(const gsl::span& Buffer) PendingBuffer.insert(PendingBuffer.end(), begin, end); } +HTTPChunkBasedReadHandle::HTTPChunkBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Line)>&& OnChunk) : + LineBasedReadHandle(std::move(MovedHandle), [this](const gsl::span& Buffer) { OnRead(Buffer); }), OnChunk(OnChunk) +{ +} + +HTTPChunkBasedReadHandle::~HTTPChunkBasedReadHandle() +{ + LOG_HR_IF(E_UNEXPECTED, !PendingBuffer.empty() || PendingChunkSize != 0); +} + +void HTTPChunkBasedReadHandle::OnRead(const gsl::span& Input) +{ + // See: https://httpwg.org/specs/rfc9112.html#field.transfer-encoding + + auto buffer = Input; + + auto advance = [&](size_t count) { + WI_ASSERT(buffer.size() >= count); + buffer = buffer.subspan(count); + }; + + while (!buffer.empty()) + { + if (PendingChunkSize == 0 || !ReadingChunk) + { + if (buffer.front() == '\r' || buffer.front() == '\n') + { + // Consume CRLF's between chunks. + advance(1); + continue; + } + } + + if (PendingChunkSize == 0) + { + auto lf = std::ranges::find(buffer, '\r'); + + THROW_HR_IF_MSG( + E_INVALIDARG, lf == buffer.end(), "Unexpected HTTP chunk trailer: %hs", std::string(buffer.data(), buffer.size()).c_str()); + + auto chunkSizeStr = std::string{buffer.begin(), lf}; + + try + { + PendingChunkSize = std::stoul(chunkSizeStr.c_str(), nullptr, 16); + } + catch (...) + { + THROW_HR_MSG(E_INVALIDARG, "Failed to parse chunk size: %hs", chunkSizeStr.c_str()); + } + + advance(chunkSizeStr.size()); + ReadingChunk = false; + } + else + { + // Consume the chunk. + ReadingChunk = true; + + auto consumedBytes = std::min(PendingChunkSize, buffer.size()); + PendingBuffer.insert(PendingBuffer.end(), buffer.data(), buffer.data() + consumedBytes); + advance(consumedBytes); + + WI_ASSERT(PendingChunkSize >= PendingChunkSize); + PendingChunkSize -= consumedBytes; + + if (PendingChunkSize == 0) + { + OnChunk(PendingBuffer); + PendingBuffer.clear(); + } + } + } +} + WriteHandle::WriteHandle(wil::unique_handle&& MovedHandle, const std::vector& Buffer) : Handle(std::move(MovedHandle)), Buffer(Buffer) { diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index b1b75f91a..610015328 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -220,7 +220,7 @@ class LineBasedReadHandle : public ReadHandle NON_COPYABLE(LineBasedReadHandle); NON_MOVABLE(LineBasedReadHandle); - LineBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OneLine); + LineBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OnLine); ~LineBasedReadHandle(); private: @@ -230,6 +230,24 @@ class LineBasedReadHandle : public ReadHandle std::string PendingBuffer; }; +class HTTPChunkBasedReadHandle : public LineBasedReadHandle +{ +public: + NON_COPYABLE(HTTPChunkBasedReadHandle); + NON_MOVABLE(HTTPChunkBasedReadHandle); + + HTTPChunkBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OnChunk); + ~HTTPChunkBasedReadHandle(); + +private: + void OnRead(const gsl::span& Line); + + std::function& Buffer)> OnChunk; + std::string PendingBuffer; + uint64_t PendingChunkSize = 0; + bool ReadingChunk = false; +}; + class WriteHandle : public OverlappedIOHandle { public: diff --git a/src/windows/wslaservice/exe/ContainerEventTracker.cpp b/src/windows/wslaservice/exe/ContainerEventTracker.cpp index e0a58d7dd..4667c766b 100644 --- a/src/windows/wslaservice/exe/ContainerEventTracker.cpp +++ b/src/windows/wslaservice/exe/ContainerEventTracker.cpp @@ -17,6 +17,7 @@ Module Name: #include using wsl::windows::service::wsla::ContainerEventTracker; +using wsl::windows::service::wsla::DockerHTTPClient; using wsl::windows::service::wsla::WSLAVirtualMachine; ContainerEventTracker::ContainerTrackingReference::ContainerTrackingReference(ContainerEventTracker* tracker, size_t id) : @@ -51,15 +52,12 @@ ContainerEventTracker::ContainerTrackingReference::~ContainerTrackingReference() Reset(); } -ContainerEventTracker::ContainerEventTracker(WSLAVirtualMachine& virtualMachine) +ContainerEventTracker::ContainerEventTracker(DockerHTTPClient& dockerClient) { - ServiceProcessLauncher launcher{nerdctlPath, {nerdctlPath, "events", "--format", "{{json .}}"}, {}, common::ProcessFlags::Stdout}; + auto socket = dockerClient.MonitorEvents(); + m_thread = std::thread([socket = std::move(socket), this]() mutable { Run(std::move(socket)); } - // Redirect stderr to /dev/null to avoid pipe deadlocks. - launcher.AddFd({.Fd = 2, .Type = WSLAFdTypeLinuxFileOutput, .Path = "/dev/null"}); - - auto process = launcher.Launch(virtualMachine); - m_thread = std::thread(std::bind(&ContainerEventTracker::Run, this, std::move(process))); + ); } void ContainerEventTracker::Stop() @@ -85,36 +83,29 @@ void ContainerEventTracker::OnEvent(const std::string& event) // TODO: log session ID WSL_LOG("NerdCtlEvent", TraceLoggingValue(event.c_str(), "Data")); - static std::map events{ - {"/tasks/create", ContainerEvent::Create}, - {"/tasks/start", ContainerEvent::Start}, - {"/tasks/stop", ContainerEvent::Stop}, - {"/tasks/exit", ContainerEvent::Exit}, - {"/tasks/destroy", ContainerEvent::Destroy}}; + static std::map events{{"start", ContainerEvent::Start}, {"die", ContainerEvent::Stop}}; auto parsed = nlohmann::json::parse(event); - auto type = parsed.find("Topic"); - auto details = parsed.find("Event"); + auto type = parsed.find("Type"); + auto action = parsed.find("Action"); + auto actor = parsed.find("Actor"); - THROW_HR_IF_MSG(E_INVALIDARG, type == parsed.end() || details == parsed.end(), "Failed to parse json: %hs", event.c_str()); + THROW_HR_IF_MSG( + E_INVALIDARG, type == parsed.end() || action == parsed.end() || actor == parsed.end(), "Failed to parse json: %hs", event.c_str()); - auto it = events.find(type->get()); + auto it = events.find(action->get()); if (it == events.end()) { return; // Event is not tracked, dropped. } - // N.B. The 'Event' field is a json string. - auto innerEventJson = details->get(); - auto innerEvent = nlohmann::json::parse(innerEventJson); - - auto containerIdIt = innerEvent.find("container_id"); - THROW_HR_IF_MSG(E_INVALIDARG, containerIdIt == innerEvent.end(), "Failed to parse json: %hs", innerEventJson.c_str()); + auto id = actor->find("ID"); + THROW_HR_IF_MSG(E_INVALIDARG, id == actor->end(), "Failed to parse json: %hs", event.c_str()); + auto containerId = id->get(); std::lock_guard lock{m_lock}; - std::string containerId = containerIdIt->get(); for (const auto& e : m_callbacks) { if (e.ContainerId == containerId) @@ -124,34 +115,32 @@ void ContainerEventTracker::OnEvent(const std::string& event) } } -void ContainerEventTracker::Run(ServiceRunningProcess& process) +void ContainerEventTracker::Run(wil::unique_socket&& socket) +try { - try - { - wsl::windows::common::relay::MultiHandleWait io; + wsl::windows::common::relay::MultiHandleWait io; - auto oneLineWritten = [&](const gsl::span& buffer) { - // nerdctl events' output is line based. Call OnEvent() for each completed line. + auto oneLineWritten = [&](const gsl::span& buffer) { + // nerdctl events' output is line based. Call OnEvent() for each completed line. - if (!buffer.empty()) // nerdctl inserts empty lines between events, skip those. - { - OnEvent(std::string{buffer.begin(), buffer.end()}); - } - }; + if (!buffer.empty()) // nerdctl inserts empty lines between events, skip those. + { + OnEvent(std::string{buffer.begin(), buffer.end()}); + } + }; - auto onStop = [&]() { io.Cancel(); }; + auto onStop = [&]() { io.Cancel(); }; - io.AddHandle(std::make_unique(process.GetStdHandle(1), std::move(oneLineWritten))); - io.AddHandle(std::make_unique(m_stopEvent.get(), std::move(onStop))); + io.AddHandle(std::make_unique(wil::unique_handle{(HANDLE)socket.release()}, std::move(oneLineWritten))); + io.AddHandle(std::make_unique(m_stopEvent.get(), std::move(onStop))); - if (io.Run({})) - { - // TODO: Report error to session. - WSL_LOG("Unexpected nerdctl exit"); - } + if (io.Run({})) + { + // TODO: Report error to session. + WSL_LOG("Unexpected nerdctl exit"); } - CATCH_LOG(); } +CATCH_LOG(); ContainerEventTracker::ContainerTrackingReference ContainerEventTracker::RegisterContainerStateUpdates( const std::string& ContainerId, ContainerStateChangeCallback&& Callback) diff --git a/src/windows/wslaservice/exe/ContainerEventTracker.h b/src/windows/wslaservice/exe/ContainerEventTracker.h index 20ec0a777..34e4db304 100644 --- a/src/windows/wslaservice/exe/ContainerEventTracker.h +++ b/src/windows/wslaservice/exe/ContainerEventTracker.h @@ -13,7 +13,7 @@ Module Name: --*/ #pragma once -#include "ServiceProcessLauncher.h" +#include "DockerHTTPClient.h" namespace wsl::windows::service::wsla { @@ -55,7 +55,7 @@ class ContainerEventTracker using ContainerStateChangeCallback = std::function; - ContainerEventTracker(WSLAVirtualMachine& virtualMachine); + ContainerEventTracker(DockerHTTPClient& dockerClient); ~ContainerEventTracker(); void Stop(); @@ -65,7 +65,7 @@ class ContainerEventTracker private: void OnEvent(const std::string& event); - void Run(ServiceRunningProcess& process); + void Run(wil::unique_socket&& Socket); struct Callback { diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp index d177006f2..7954b0937 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.cpp +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -33,12 +33,19 @@ void DockerHTTPClient::ResizeContainerTty(const std::string& Id, ULONG Rows, ULO Transaction(verb::post, std::format("http://localhost/containers/{}/resize?w={}&h={}", Id, Columns, Rows)); } -DockerHTTPClient::RequestResult DockerHTTPClient::StartContainer(const std::string& Id) +void DockerHTTPClient::StartContainer(const std::string& Id) { - RequestResult result; - std::tie(result.StatusCode, result.ResponseString) = Transaction(verb::post, std::format("http://localhost/containers/{}/start", Id)); + Transaction(verb::post, std::format("http://localhost/containers/{}/start", Id)); +} + +void DockerHTTPClient::StopContainer(const std::string& Id, int Signal, ULONG TimeoutSeconds) +{ + Transaction(verb::post, std::format("http://localhost/containers/{}/stop?signal={}&t={}", Id, Signal, TimeoutSeconds)); +} - return result; +void DockerHTTPClient::DeleteContainer(const std::string& Id) +{ + Transaction(verb::delete_, std::format("http://localhost/containers/{}", Id)); } wil::unique_socket DockerHTTPClient::AttachContainer(const std::string& Id) @@ -46,10 +53,26 @@ wil::unique_socket DockerHTTPClient::AttachContainer(const std::string& Id) std::map headers{ {boost::beast::http::field::upgrade, "tcp"}, {boost::beast::http::field::connection, "upgrade"}}; - auto [status, socket] = SendRequest( - verb::post, std::format("http://localhost/containers/{}/attach?stream=1&stdin=1&stdout=1&stderr=1&logs=true", Id), {}, {}, headers); + auto url = std::format("http://localhost/containers/{}/attach?stream=1&stdin=1&stdout=1&stderr=1&logs=true", Id); + auto [status, socket] = SendRequest(verb::post, url, {}, {}, headers); + + if (status != 101) + { + throw DockerHTTPException(status, url, "", ""); + } - THROW_HR_IF_MSG(E_FAIL, status != 101, "Failed to attach to container %hs: %i", Id.c_str(), status); + return std::move(socket); +} + +wil::unique_socket DockerHTTPClient::MonitorEvents() +{ + auto url = "http://localhost/events"; + auto [status, socket] = SendRequest(verb::get, url, {}, {}); + + if (status != 200) + { + throw DockerHTTPException(status, url, "", ""); + } return std::move(socket); } @@ -80,7 +103,7 @@ wil::unique_socket DockerHTTPClient::ConnectSocket() return newChannel.Release(); } -std::pair DockerHTTPClient::Transaction(verb Method, const std::string& Url, const std::string& Body) +std::pair DockerHTTPClient::SendRequest(verb Method, const std::string& Url, const std::string& Body) { std::string responseBody; auto OnResponse = [&responseBody](const gsl::span& span) { responseBody.append(span.data(), span.size()); }; diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.h b/src/windows/wslaservice/exe/DockerHTTPClient.h index fd9f6b034..31c4c2c53 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.h +++ b/src/windows/wslaservice/exe/DockerHTTPClient.h @@ -8,10 +8,10 @@ namespace wsl::windows::service::wsla { -class DockerHTTPExceptions : std::runtime_error +class DockerHTTPException : std::runtime_error { public: - DockerHTTPExceptions(uint16_t StatusCode, const std::string& Url, const std::string& RequestContent, const std::string& ResponseContent) : + DockerHTTPException(uint16_t StatusCode, const std::string& Url, const std::string& RequestContent, const std::string& ResponseContent) : std::runtime_error(std::format("HTTP request failed: {} -> {} (Request: {}, Response: {})", Url, StatusCode, RequestContent, ResponseContent)), m_statusCode(StatusCode), m_url(Url), @@ -71,40 +71,43 @@ class DockerHTTPClient DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE ExitingEvent, GUID VmId, ULONG ConnectTimeoutMs); docker_schema::CreatedContainer CreateContainer(const docker_schema::CreateContainer& Request); - RequestResult StartContainer(const std::string& Id); + void StartContainer(const std::string& Id); + void StopContainer(const std::string& Id, int Signal, ULONG TimeoutSeconds); + void DeleteContainer(const std::string& Id); wil::unique_socket AttachContainer(const std::string& Id); + wil::unique_socket MonitorEvents(); void ResizeContainerTty(const std::string& Id, ULONG Rows, ULONG Columns); uint32_t PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback); + +private: + wil::unique_socket ConnectSocket(); + std::pair SendRequest( + boost::beast::http::verb Method, const std::string& Url, const std::string& Body = ""); + std::pair SendRequest( boost::beast::http::verb Method, const std::string& Url, - const std::string& Body = "", - const OnResponseBytes& OnResponse = {}, + const std::string& Body, + const OnResponseBytes& OnResponse, const std::map& Headers = {}); - std::pair Transaction( - boost::beast::http::verb Method, const std::string& Url, const std::string& Body = ""); - -private: - wil::unique_socket ConnectSocket(); - - template + template auto Transaction(boost::beast::http::verb Method, const std::string& Url, const TRequest& RequestObject = {}) { std::string requestString; - if constexpr (!std::is_same_v) + if constexpr (!std::is_same_v) { requestString = wsl::shared::ToJson(RequestObject); } - auto [statusCode, responseString] = Transaction(Method, Url, requestString); + auto [statusCode, responseString] = SendRequest(Method, Url, requestString); if (statusCode < 200 || statusCode >= 300) { - throw DockerHTTPExceptions(statusCode, Url, requestString, responseString); + throw DockerHTTPException(statusCode, Url, requestString, responseString); } if constexpr (!std::is_same_v) diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 4596a3b9c..34f52c720 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -141,6 +141,7 @@ WSLAContainerImpl::WSLAContainerImpl( std::vector&& volumes, std::vector&& ports, std::function&& onDeleted, + ContainerEventTracker& EventTracker, DockerHTTPClient& DockerClient) : m_parentVM(parentVM), m_name(Options.Name), @@ -149,8 +150,9 @@ WSLAContainerImpl::WSLAContainerImpl( m_mountedVolumes(std::move(volumes)), m_mappedPorts(std::move(ports)), m_comWrapper(wil::MakeOrThrow(this, std::move(onDeleted))), - m_dockerClient(DockerClient) -{ + m_dockerClient(DockerClient), + m_containerEvents(EventTracker.RegisterContainerStateUpdates(m_id, std::bind(&WSLAContainerImpl::OnEvent, this, std::placeholders::_1))) +{ m_state = WslaContainerStateCreated; } @@ -162,6 +164,8 @@ WSLAContainerImpl::~WSLAContainerImpl() TraceLoggingValue(m_id.c_str(), "Id"), TraceLoggingValue((int)m_state, "State")); + m_containerEvents.Reset(); + // Disconnect from the COM instance. After this returns, no COM calls can be made to this instance. m_comWrapper->Disconnect(); @@ -222,8 +226,7 @@ void WSLAContainerImpl::Start() m_initProcess.emplace(std::string{m_id}, wil::unique_handle{(HANDLE)m_dockerClient.AttachContainer(m_id).release()}, true, m_dockerClient); auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [this]() mutable { m_initProcess.reset(); }); - auto result = m_dockerClient.StartContainer(m_id); - THROW_HR_IF_MSG(E_FAIL, result.StatusCode != 204, "Failed to start container: %hs, %hs", m_id.c_str(), result.Format().c_str()); + m_dockerClient.StartContainer(m_id); m_state = WslaContainerStateRunning; cleanup.release(); @@ -238,9 +241,10 @@ void WSLAContainerImpl::GetTtyHandle(ULONG* Handle) void WSLAContainerImpl::OnEvent(ContainerEvent event) { - if (event == ContainerEvent::Start) + if (event == ContainerEvent::Stop) { - m_startedEvent.SetEvent(); + std::lock_guard lock(m_lock); + m_state = WslaContainerStateExited; } WSL_LOG( @@ -259,30 +263,7 @@ void WSLAContainerImpl::Stop(int Signal, ULONG TimeoutMs) return; } - /* 'nerdctl stop ...' - * returns success and on stdout if the container is running or already stopped - * returns error "No such container: " on stderr if the container is in 'Created' state or does not exist - * - * For our case, we treat stopping an already-exited container as a no-op and return success. - * Stopping a deleted or created container returns ERROR_INVALID_STATE. - * TODO: Discuss and return stdout/stderr or corresponding HRESULT from nerdctl stop for better diagnostics. - */ - - // Validate that the container is in the running state. - THROW_HR_IF_MSG( - HRESULT_FROM_WIN32(ERROR_INVALID_STATE), - m_state != WslaContainerStateRunning, - "Container '%hs' is not in a stoppable state: %i", - m_name.c_str(), - m_state); - ServiceProcessLauncher launcher( - nerdctlPath, {nerdctlPath, "stop", m_name, "--time", std::to_string(static_cast(std::round(TimeoutMs / 1000)))}, defaultNerdctlEnv); - - // TODO: Figure out how we want to handle custom signals. - // nerdctl stop has a --time and a --signal option that can be used - // By default, it uses SIGTERM and a default timeout of 10 seconds. - auto result = launcher.Launch(*m_parentVM).WaitAndCaptureOutput(); - THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "%hs", launcher.FormatResult(result).c_str()); + m_dockerClient.StopContainer(m_id, Signal, static_cast(std::round(TimeoutMs / 1000))); m_state = WslaContainerStateExited; } @@ -299,9 +280,7 @@ void WSLAContainerImpl::Delete() m_name.c_str(), m_state); - ServiceProcessLauncher launcher(nerdctlPath, {nerdctlPath, "rm", "-f", m_name}, defaultNerdctlEnv); - auto result = launcher.Launch(*m_parentVM).WaitAndCaptureOutput(deleteTimeout); - THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "%hs", launcher.FormatResult(result).c_str()); + m_dockerClient.DeleteContainer(m_id); UnmountVolumes(m_mountedVolumes, *m_parentVM); @@ -455,6 +434,7 @@ std::unique_ptr WSLAContainerImpl::Create( const WSLA_CONTAINER_OPTIONS& containerOptions, WSLAVirtualMachine& parentVM, std::function&& OnDeleted, + ContainerEventTracker& EventTracker, DockerHTTPClient& DockerClient) { // TODO: Think about when 'StdinOnce' should be set. @@ -471,7 +451,7 @@ std::unique_ptr WSLAContainerImpl::Create( // N.B. mappedPorts is explicitly copied because it's referenced in errorCleanup, so it can't be moved. auto container = std::make_unique( - &parentVM, containerOptions, std::move(result.Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), DockerClient); + &parentVM, containerOptions, std::move(result.Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), EventTracker, DockerClient); errorCleanup.release(); diff --git a/src/windows/wslaservice/exe/WSLAContainer.h b/src/windows/wslaservice/exe/WSLAContainer.h index 6141d8f72..02d521b41 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.h +++ b/src/windows/wslaservice/exe/WSLAContainer.h @@ -55,6 +55,7 @@ class WSLAContainerImpl std::vector&& volumes, std::vector&& ports, std::function&& OnDeleted, + ContainerEventTracker& EventTracker, DockerHTTPClient& DockerClient); ~WSLAContainerImpl(); @@ -76,6 +77,7 @@ class WSLAContainerImpl const WSLA_CONTAINER_OPTIONS& Options, WSLAVirtualMachine& parentVM, std::function&& OnDeleted, + ContainerEventTracker& EventTracker, DockerHTTPClient& DockerClient); private: @@ -85,7 +87,6 @@ class WSLAContainerImpl std::optional GetNerdctlStatus(); std::recursive_mutex m_lock; - wil::unique_event m_startedEvent{wil::EventOptions::ManualReset}; std::string m_name; std::string m_image; std::string m_id; @@ -96,6 +97,7 @@ class WSLAContainerImpl std::vector m_mountedVolumes; Microsoft::WRL::ComPtr m_comWrapper; std::optional m_initProcess; + ContainerEventTracker::ContainerTrackingReference m_containerEvents; static std::vector PrepareNerdctlCreateCommand( const WSLA_CONTAINER_OPTIONS& options, std::vector&& inputOptions, std::vector& volumes); diff --git a/src/windows/wslaservice/exe/WSLAContainerProcess.cpp b/src/windows/wslaservice/exe/WSLAContainerProcess.cpp new file mode 100644 index 000000000..fd6cb4215 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAContainerProcess.cpp @@ -0,0 +1,75 @@ +#include "precomp.h" +#include "WSLAContainerProcess.h" + +using wsl::windows::service::wsla::WSLAContainerProcess; + +WSLAContainerProcess::WSLAContainerProcess(std::string&& Id, wil::unique_handle&& IoStream, bool Tty, DockerHTTPClient& client) : + m_id(std::move(Id)), m_ioStream(std::move(IoStream)), m_dockerClient(client), m_tty(Tty) +{ +} + +HRESULT WSLAContainerProcess::Signal(_In_ int Signal) +{ + return E_NOTIMPL; +} + +HRESULT WSLAContainerProcess::GetExitEvent(_Out_ ULONG* Event) +{ + return E_NOTIMPL; +} + +HRESULT WSLAContainerProcess::GetStdHandle(_In_ ULONG Index, _Out_ ULONG* Handle) +try +{ + std::lock_guard lock{m_mutex}; + + auto& socket = GetStdHandle(Index); + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !socket.is_valid()); + + *Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess(socket.get())); + WSL_LOG( + "GetStdHandle", + TraceLoggingValue(Index, "fd"), + TraceLoggingValue(socket.get(), "handle"), + TraceLoggingValue(*Handle, "remoteHandle")); + + socket.reset(); + return S_OK; + + return E_NOTIMPL; +} +CATCH_RETURN(); + +HRESULT WSLAContainerProcess::GetPid(_Out_ int* Pid) +{ + return E_NOTIMPL; +} + +HRESULT WSLAContainerProcess::GetState(_Out_ WSLA_PROCESS_STATE* State, _Out_ int* Code) +{ + return E_NOTIMPL; +} + +HRESULT WSLAContainerProcess::ResizeTty(_In_ ULONG Rows, _In_ ULONG Columns) +try +{ + std::lock_guard lock{m_mutex}; + RETURN_HR_IF(E_INVALIDARG, !m_tty); + + m_dockerClient.ResizeContainerTty(m_id, Rows, Columns); + + return S_OK; +} +CATCH_RETURN(); + +wil::unique_handle& WSLAContainerProcess::GetStdHandle(int Index) +{ + std::lock_guard lock{m_mutex}; + + if (Index == 0 && m_tty) + { + return m_ioStream; + } + + return m_ioStream; // TODO: fix +} \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAContainerProcess.h b/src/windows/wslaservice/exe/WSLAContainerProcess.h new file mode 100644 index 000000000..389df6bd8 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAContainerProcess.h @@ -0,0 +1,31 @@ +#pragma once + +#include "DockerHTTPClient.h" +#include "wslaservice.h" + +namespace wsl::windows::service::wsla { + +class DECLSPEC_UUID("3A5DB29D-6D1D-4619-B89D-578EB34C8E52") WSLAContainerProcess + : public Microsoft::WRL::RuntimeClass, IWSLAProcess, IFastRundown> +{ +public: + WSLAContainerProcess(std::string&& Id, wil::unique_handle&& IoStream, bool Tty, DockerHTTPClient& client); + + IFACEMETHOD(Signal)(_In_ int Signal) override; + IFACEMETHOD(GetExitEvent)(_Out_ ULONG* Event) override; + IFACEMETHOD(GetStdHandle)(_In_ ULONG Index, _Out_ ULONG* Handle) override; + IFACEMETHOD(GetPid)(_Out_ int* Pid) override; + IFACEMETHOD(GetState)(_Out_ WSLA_PROCESS_STATE* State, _Out_ int* Code) override; + IFACEMETHOD(ResizeTty)(_In_ ULONG Rows, _In_ ULONG Columns) override; + +private: + wil::unique_handle& GetStdHandle(int Index); + + wil::unique_handle m_ioStream; + DockerHTTPClient& m_dockerClient; + bool m_tty = false; + std::string m_id; + std::recursive_mutex m_mutex; +}; + +} // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 4214bd93d..b4541eab0 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -69,9 +69,8 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs m_dockerClient.emplace(std::move(channel), m_virtualMachine->ExitingEvent(), m_virtualMachine->VmId(), 10 * 1000); - // WSL_LOG("Info", TraceLoggingValue(response.c_str(), "DockerInfo")); // Start the event tracker. - // m_eventTracker.emplace(*m_virtualMachine.Get()); + m_eventTracker.emplace(m_dockerClient.value()); errorCleanup.release(); } @@ -417,6 +416,7 @@ try *containerOptions, *m_virtualMachine.Get(), std::bind(&WSLASession::OnContainerDeleted, this, std::placeholders::_1), + m_eventTracker.value(), m_dockerClient.value())); WI_ASSERT(inserted); diff --git a/src/windows/wslaservice/exe/docker_schema.h b/src/windows/wslaservice/exe/docker_schema.h index 22358fd35..859ed055a 100644 --- a/src/windows/wslaservice/exe/docker_schema.h +++ b/src/windows/wslaservice/exe/docker_schema.h @@ -19,13 +19,9 @@ struct ErrorResponse NLOHMANN_DEFINE_TYPE_INTRUSIVE(ErrorResponse, Message); }; -struct EmtpyResponse -{ -}; - struct EmtpyRequest { - using TResponse = EmtpyResponse; + using TResponse = void; }; struct CreateContainer From 99fd48daec311293e42fc12adfbce79178873785 Mon Sep 17 00:00:00 2001 From: Blue Date: Mon, 29 Dec 2025 14:17:24 -0800 Subject: [PATCH 09/26] Save state --- src/windows/common/WslClient.cpp | 21 +- src/windows/common/relay.cpp | 271 ++++++++++++++++-- src/windows/common/relay.hpp | 84 +++++- .../wslaservice/exe/ContainerEventTracker.cpp | 2 +- src/windows/wslaservice/exe/WSLAContainer.cpp | 20 +- .../wslaservice/exe/WSLAContainerProcess.cpp | 106 ++++++- .../wslaservice/exe/WSLAContainerProcess.h | 10 + src/windows/wslaservice/exe/docker_schema.h | 2 + 8 files changed, 473 insertions(+), 43 deletions(-) diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index a14883f00..97c56f961 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1542,6 +1542,7 @@ int WslaShell(_In_ std::wstring_view commandLine) std::string containerImage; bool help = false; + bool noTty = false; std::wstring debugShell; std::wstring storagePath; @@ -1561,6 +1562,7 @@ int WslaShell(_In_ std::wstring_view commandLine) parser.AddArgument(Integer(reinterpret_cast(sessionSettings.NetworkingMode)), L"--networking-mode"); parser.AddArgument(Utf8String(containerImage), L"--image"); parser.AddArgument(debugShell, L"--debug-shell"); + parser.AddArgument(noTty, L"--no-tty"); parser.AddArgument(help, L"--help"); parser.Parse(); @@ -1645,11 +1647,19 @@ int WslaShell(_In_ std::wstring_view commandLine) { THROW_IF_FAILED(session->PullImage(containerImage.c_str(), nullptr, nullptr)); - std::vector fds{ - WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}, - WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}, - WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl}, - }; + std::vector fds; + + if (noTty) + { + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeDefault}); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeDefault}); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeDefault}); + } + else + { + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}); + } WSLA_CONTAINER_OPTIONS containerOptions{}; containerOptions.Image = containerImage.c_str(); @@ -1666,7 +1676,6 @@ int WslaShell(_In_ std::wstring_view commandLine) wil::com_ptr initProcess; THROW_IF_FAILED((*container)->GetInitProcess(&initProcess)); process.emplace(std::move(initProcess), std::move(fds)); - } // Save original console modes so they can be restored on exit. diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index be79057ff..9a81765a5 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -16,13 +16,16 @@ Module Name: #include "relay.hpp" #pragma hdrstop +using wsl::windows::common::relay::DockerIORelayHandle; using wsl::windows::common::relay::EventHandle; +using wsl::windows::common::relay::HandleWrapper; using wsl::windows::common::relay::HTTPChunkBasedReadHandle; using wsl::windows::common::relay::IOHandleStatus; using wsl::windows::common::relay::LineBasedReadHandle; using wsl::windows::common::relay::MultiHandleWait; using wsl::windows::common::relay::OverlappedIOHandle; using wsl::windows::common::relay::ReadHandle; +using wsl::windows::common::relay::RelayHandle; using wsl::windows::common::relay::ScopedMultiRelay; using wsl::windows::common::relay::ScopedRelay; using wsl::windows::common::relay::WriteHandle; @@ -1038,7 +1041,7 @@ HANDLE EventHandle::GetHandle() const return Handle; } -ReadHandle::ReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OnRead) : +ReadHandle::ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead) : Handle(std::move(MovedHandle)), OnRead(OnRead) { Overlapped.hEvent = Event.get(); @@ -1049,9 +1052,9 @@ ReadHandle::~ReadHandle() if (State == IOHandleStatus::Pending) { DWORD bytesRead{}; - if (CancelIoEx(Handle.get(), &Overlapped)) + if (CancelIoEx(Handle.Get(), &Overlapped)) { - LOG_LAST_ERROR_IF(!GetOverlappedResult(Handle.get(), &Overlapped, &bytesRead, true) && GetLastError() != ERROR_CONNECTION_ABORTED); + LOG_LAST_ERROR_IF(!GetOverlappedResult(Handle.Get(), &Overlapped, &bytesRead, true) && GetLastError() != ERROR_CONNECTION_ABORTED); } else { @@ -1069,7 +1072,7 @@ void ReadHandle::Schedule() // Schedule the read. DWORD bytesRead{}; - if (ReadFile(Handle.get(), Buffer.data(), static_cast(Buffer.size()), &bytesRead, &Overlapped)) + if (ReadFile(Handle.Get(), Buffer.data(), static_cast(Buffer.size()), &bytesRead, &Overlapped)) { // Signal the read. OnRead(gsl::make_span(Buffer.data(), static_cast(bytesRead))); @@ -1092,7 +1095,7 @@ void ReadHandle::Schedule() return; } - THROW_LAST_ERROR_IF_MSG(error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)Handle.get()); + THROW_LAST_ERROR_IF_MSG(error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)Handle.Get()); // The read is pending, update to 'Pending' State = IOHandleStatus::Pending; @@ -1108,7 +1111,7 @@ void ReadHandle::Collect() // Complete the read. DWORD bytesRead{}; - if (!GetOverlappedResult(Handle.get(), &Overlapped, &bytesRead, false)) + if (!GetOverlappedResult(Handle.Get(), &Overlapped, &bytesRead, false)) { auto error = GetLastError(); THROW_WIN32_IF(error, error != ERROR_HANDLE_EOF && error != ERROR_BROKEN_PIPE); @@ -1132,8 +1135,8 @@ HANDLE ReadHandle::GetHandle() const return Event.get(); } -LineBasedReadHandle::LineBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Line)>&& OnLine) : - ReadHandle(std::move(MovedHandle), [this](const gsl::span& Buffer) { OnRead(Buffer); }), OnLine(OnLine) +LineBasedReadHandle::LineBasedReadHandle(HandleWrapper&& Handle, std::function& Line)>&& OnLine) : + ReadHandle(std::move(Handle), [this](const gsl::span& Buffer) { OnRead(Buffer); }), OnLine(OnLine) { } @@ -1164,7 +1167,7 @@ void LineBasedReadHandle::OnRead(const gsl::span& Buffer) PendingBuffer.insert(PendingBuffer.end(), begin, end); } -HTTPChunkBasedReadHandle::HTTPChunkBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Line)>&& OnChunk) : +HTTPChunkBasedReadHandle::HTTPChunkBasedReadHandle(HandleWrapper&& MovedHandle, std::function& Line)>&& OnChunk) : LineBasedReadHandle(std::move(MovedHandle), [this](const gsl::span& Buffer) { OnRead(Buffer); }), OnChunk(OnChunk) { } @@ -1239,7 +1242,7 @@ void HTTPChunkBasedReadHandle::OnRead(const gsl::span& Input) } } -WriteHandle::WriteHandle(wil::unique_handle&& MovedHandle, const std::vector& Buffer) : +WriteHandle::WriteHandle(HandleWrapper&& MovedHandle, const std::vector& Buffer) : Handle(std::move(MovedHandle)), Buffer(Buffer) { Overlapped.hEvent = Event.get(); @@ -1250,9 +1253,9 @@ WriteHandle::~WriteHandle() if (State == IOHandleStatus::Pending) { DWORD bytesRead{}; - if (CancelIoEx(Handle.get(), &Overlapped)) + if (CancelIoEx(Handle.Get(), &Overlapped)) { - LOG_LAST_ERROR_IF(!GetOverlappedResult(Handle.get(), &Overlapped, &bytesRead, true) && GetLastError() != ERROR_CONNECTION_ABORTED); + LOG_LAST_ERROR_IF(!GetOverlappedResult(Handle.Get(), &Overlapped, &bytesRead, true) && GetLastError() != ERROR_CONNECTION_ABORTED); } else { @@ -1270,10 +1273,10 @@ void WriteHandle::Schedule() // Schedule the write. DWORD bytesWritten{}; - if (WriteFile(Handle.get(), Buffer.data() + Offset, static_cast(Buffer.size() - Offset), &bytesWritten, &Overlapped)) + if (WriteFile(Handle.Get(), Buffer.data(), static_cast(Buffer.size()), &bytesWritten, &Overlapped)) { - Offset += bytesWritten; - if (Offset >= Buffer.size()) + Buffer.erase(Buffer.begin(), Buffer.begin() + bytesWritten); + if (Buffer.empty()) { State = IOHandleStatus::Completed; } @@ -1281,7 +1284,7 @@ void WriteHandle::Schedule() else { auto error = GetLastError(); - THROW_LAST_ERROR_IF_MSG(error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)Handle.get()); + THROW_LAST_ERROR_IF_MSG(error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)Handle.Get()); // The write is pending, update to 'Pending' State = IOHandleStatus::Pending; @@ -1297,16 +1300,246 @@ void WriteHandle::Collect() // Complete the write. DWORD bytesWritten{}; - THROW_IF_WIN32_BOOL_FALSE(GetOverlappedResult(Handle.get(), &Overlapped, &bytesWritten, false)); + THROW_IF_WIN32_BOOL_FALSE(GetOverlappedResult(Handle.Get(), &Overlapped, &bytesWritten, false)); - Offset += bytesWritten; - if (Offset >= Buffer.size()) + Buffer.erase(Buffer.begin(), Buffer.begin() + bytesWritten); + if (Buffer.empty()) { State = IOHandleStatus::Completed; } } +void WriteHandle::Push(const gsl::span& Content) +{ + // Don't write if a WriteFile() is pending, since that could cause the buffer to reallocate. + WI_ASSERT(State == IOHandleStatus::Standby || State == IOHandleStatus::Completed); + WI_ASSERT(!Content.empty()); + + Buffer.insert(Buffer.end(), Content.begin(), Content.end()); + + State = IOHandleStatus::Standby; +} + HANDLE WriteHandle::GetHandle() const { return Event.get(); +} + +RelayHandle::RelayHandle(HandleWrapper&& ReadHandle, HandleWrapper&& WriteHandle) : + Read(std::move(ReadHandle), [this](const gsl::span& Buffer) { return OnRead(Buffer); }), Write(std::move(WriteHandle)) +{ +} + +void RelayHandle::Schedule() +{ + WI_ASSERT(State == IOHandleStatus::Standby); + + // If the Buffer is empty, then we're reading. + if (PendingBuffer.empty()) + { + // If the output buffer is empty and the reading end is completed, then we're done. + if (Read.GetState() == IOHandleStatus::Completed) + { + State = IOHandleStatus::Completed; + return; + } + + Read.Schedule(); + + // If the read is pending, update to 'Pending' + if (Read.GetState() == IOHandleStatus::Pending) + { + State = IOHandleStatus::Pending; + } + } + else + { + Write.Push(PendingBuffer); + PendingBuffer.clear(); + + Write.Schedule(); + + if (Write.GetState() == IOHandleStatus::Pending) + { + // The write is pending, update to 'Pending' + State = IOHandleStatus::Pending; + } + } +} + +void RelayHandle::OnRead(const gsl::span& Content) +{ + WI_ASSERT(PendingBuffer.empty()); + + PendingBuffer.insert(PendingBuffer.end(), Content.begin(), Content.end()); +} + +void RelayHandle::Collect() +{ + WI_ASSERT(State == IOHandleStatus::Pending); + + // Transition back to standby + State = IOHandleStatus::Standby; + + if (Read.GetState() == IOHandleStatus::Pending) + { + Read.Collect(); + } + else + { + WI_ASSERT(Write.GetState() == IOHandleStatus::Pending); + Write.Collect(); + } +} + +HANDLE RelayHandle::GetHandle() const +{ + if (Read.GetState() == IOHandleStatus::Pending) + { + return Read.GetHandle(); + } + else + { + WI_ASSERT(Write.GetState() == IOHandleStatus::Pending); + return Write.GetHandle(); + } +} + +DockerIORelayHandle::DockerIORelayHandle(HandleWrapper&& ReadHandle, HandleWrapper&& Stdout, HandleWrapper&& Stderr) : + Read(std::move(ReadHandle), [this](const gsl::span& Buffer) { return OnRead(Buffer); }), + WriteStdout(std::move(Stdout)), + WriteStderr(std::move(Stderr)) +{ +} + +void DockerIORelayHandle::Schedule() +{ + WI_ASSERT(State == IOHandleStatus::Standby); + + // If we have an active handle and a buffer, try to flush that first. + if (ActiveHandle != nullptr) + { + // Push the data to the selected handle. + DWORD bytesToWrite = std::min(static_cast(RemainingBytes), static_cast(PendingBuffer.size())); + + ActiveHandle->Push(gsl::make_span(PendingBuffer.data(), bytesToWrite)); + + // Consume the written bytes. + RemainingBytes -= bytesToWrite; + PendingBuffer.erase(PendingBuffer.begin(), PendingBuffer.begin() + bytesToWrite); + + // Schedule the write. + ActiveHandle->Schedule(); + + // If the write is pending, update to 'Pending' + if (ActiveHandle->GetState() == IOHandleStatus::Pending) + { + State = IOHandleStatus::Pending; + } + else if (ActiveHandle->GetState() == IOHandleStatus::Completed) + { + // Switch back to reading if we've written all bytes for this chunk. + ActiveHandle = nullptr; + } + } + else + { + // Schedule a read from the input. + Read.Schedule(); + if (Read.GetState() == IOHandleStatus::Pending) + { + State = IOHandleStatus::Pending; + } + } +} + +void DockerIORelayHandle::Collect() +{ + WI_ASSERT(State == IOHandleStatus::Pending); + + if (ActiveHandle != nullptr) + { + // Complete the write. + ActiveHandle->Collect(); + + // If the write is completed, switch back to reading. + if (ActiveHandle->GetState() == IOHandleStatus::Completed) + { + ActiveHandle = nullptr; + } + + // Transition back to standby if there's still data to read. + // Otherwise switch to Completed since everything is done. + if (Read.GetState() == IOHandleStatus::Completed) + { + State = IOHandleStatus::Completed; + } + else + { + State = IOHandleStatus::Standby; + } + } + else + { + // Complete the read. + Read.Collect(); + + // Transition back to standby. + State = IOHandleStatus::Standby; + } +} + +HANDLE DockerIORelayHandle::GetHandle() const +{ + if (ActiveHandle != nullptr) + { + return ActiveHandle->GetHandle(); + } + else + { + return Read.GetHandle(); + } +} + +void DockerIORelayHandle::OnRead(const gsl::span& Buffer) +{ + +#pragma pack(push, 1) + struct MultiplexedHeader + { + uint8_t Fd; + char Zeroes[3]; + uint32_t Length; + }; +#pragma pack(pop) + + static_assert(sizeof(MultiplexedHeader) == 8); + + PendingBuffer.insert(PendingBuffer.end(), Buffer.begin(), Buffer.end()); + + if (ActiveHandle == nullptr) + { + // If no handle is active, expect a header. + if (PendingBuffer.size() < sizeof(MultiplexedHeader)) + { + // Not enough data for a header yet. + return; + } + + const auto* header = reinterpret_cast(PendingBuffer.data()); + RemainingBytes = ntohl(header->Length); + + if (header->Fd == 1) + { + ActiveHandle = &WriteStdout; + } + else if (header->Fd == 2) + { + ActiveHandle = &WriteStderr; + } + else + { + THROW_HR_MSG(E_UNEXPECTED, "Unexpected Docker IO multiplexed header fd: %u", header->Fd); + } + } } \ No newline at end of file diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index 610015328..7ea4ffba8 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -159,6 +159,32 @@ enum class IOHandleStatus Completed }; +struct HandleWrapper +{ + HandleWrapper(wil::unique_handle&& handle) : OwnedHandle(std::move(handle)), Handle(OwnedHandle.get()) + { + } + + HandleWrapper(HANDLE handle) : Handle(handle) + { + } + + HANDLE Get() const + { + return Handle; + } + + void Reset() + { + OwnedHandle.reset(); + Handle = nullptr; + } + +private: + wil::unique_handle OwnedHandle; + HANDLE Handle{}; +}; + class OverlappedIOHandle { public: @@ -200,14 +226,14 @@ class ReadHandle : public OverlappedIOHandle NON_COPYABLE(ReadHandle); NON_MOVABLE(ReadHandle); - ReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OnRead); + ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead); virtual ~ReadHandle(); void Schedule() override; void Collect() override; HANDLE GetHandle() const override; private: - wil::unique_handle Handle; + HandleWrapper Handle; std::function& Buffer)> OnRead; wil::unique_event Event{wil::EventOptions::ManualReset}; OVERLAPPED Overlapped{}; @@ -220,7 +246,7 @@ class LineBasedReadHandle : public ReadHandle NON_COPYABLE(LineBasedReadHandle); NON_MOVABLE(LineBasedReadHandle); - LineBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OnLine); + LineBasedReadHandle(HandleWrapper&& Handle, std::function& Buffer)>&& OnLine); ~LineBasedReadHandle(); private: @@ -236,7 +262,7 @@ class HTTPChunkBasedReadHandle : public LineBasedReadHandle NON_COPYABLE(HTTPChunkBasedReadHandle); NON_MOVABLE(HTTPChunkBasedReadHandle); - HTTPChunkBasedReadHandle(wil::unique_handle&& MovedHandle, std::function& Buffer)>&& OnChunk); + HTTPChunkBasedReadHandle(HandleWrapper&& Handler, std::function& Buffer)>&& OnChunk); ~HTTPChunkBasedReadHandle(); private: @@ -254,18 +280,60 @@ class WriteHandle : public OverlappedIOHandle NON_COPYABLE(WriteHandle); NON_MOVABLE(WriteHandle); - WriteHandle(wil::unique_handle&& MovedHandle, const std::vector& Buffer); + WriteHandle(HandleWrapper&& Handle, const std::vector& Buffer = {}); ~WriteHandle(); void Schedule() override; void Collect() override; HANDLE GetHandle() const override; + void Push(const gsl::span& Buffer); private: - wil::unique_handle Handle; + HandleWrapper Handle; wil::unique_event Event{wil::EventOptions::ManualReset}; OVERLAPPED Overlapped{}; - const std::vector& Buffer; - DWORD Offset = 0; + std::vector Buffer; +}; + +class RelayHandle : public OverlappedIOHandle +{ +public: + NON_COPYABLE(RelayHandle); + NON_MOVABLE(RelayHandle); + + RelayHandle(HandleWrapper&& Input, HandleWrapper&& Output); + + void Schedule() override; + void Collect() override; + HANDLE GetHandle() const override; + +private: + void OnRead(const gsl::span& Buffer); + + ReadHandle Read; + WriteHandle Write; + std::vector PendingBuffer; +}; + +class DockerIORelayHandle : public OverlappedIOHandle +{ +public: + NON_COPYABLE(DockerIORelayHandle); + NON_MOVABLE(DockerIORelayHandle); + + DockerIORelayHandle(HandleWrapper&& Input, HandleWrapper&& Stdout, HandleWrapper&& Stderr); + void Schedule() override; + void Collect() override; + HANDLE GetHandle() const override; + +private: + void OnRead(const gsl::span& Buffer); + + ReadHandle Read; + WriteHandle WriteStdout; + WriteHandle WriteStderr; + std::vector PendingBuffer; + WriteHandle* ActiveHandle = nullptr; + size_t RemainingBytes = 0; }; class MultiHandleWait diff --git a/src/windows/wslaservice/exe/ContainerEventTracker.cpp b/src/windows/wslaservice/exe/ContainerEventTracker.cpp index 4667c766b..f1d963e5b 100644 --- a/src/windows/wslaservice/exe/ContainerEventTracker.cpp +++ b/src/windows/wslaservice/exe/ContainerEventTracker.cpp @@ -81,7 +81,7 @@ ContainerEventTracker::~ContainerEventTracker() void ContainerEventTracker::OnEvent(const std::string& event) { // TODO: log session ID - WSL_LOG("NerdCtlEvent", TraceLoggingValue(event.c_str(), "Data")); + WSL_LOG("DockerEvent", TraceLoggingValue(event.c_str(), "Data")); static std::map events{{"start", ContainerEvent::Start}, {"die", ContainerEvent::Stop}}; diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 34f52c720..5fa2633cc 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -245,6 +245,9 @@ void WSLAContainerImpl::OnEvent(ContainerEvent event) { std::lock_guard lock(m_lock); m_state = WslaContainerStateExited; + + // TODO: propagate exit code. + m_initProcess->OnExited(0); } WSL_LOG( @@ -440,8 +443,21 @@ std::unique_ptr WSLAContainerImpl::Create( // TODO: Think about when 'StdinOnce' should be set. auto [hasStdin, hasTty] = ParseFdStatus(containerOptions.InitProcessOptions); - auto result = DockerClient.CreateContainer( - {.Image = containerOptions.Image, .Tty = hasTty, .OpenStdin = true, .StdinOnce = true, .AttachStdin = false, .AttachStdout = false, .AttachStderr = false}); + docker_schema::CreateContainer request; + request.Image = containerOptions.Image; + + if (hasTty) + { + request.Tty = true; + } + + if (hasStdin) + { + request.OpenStdin = true; + request.StdinOnce = true; + } + + auto result = DockerClient.CreateContainer(request); // TODO: Rethink command line generation logic. std::vector dummy; diff --git a/src/windows/wslaservice/exe/WSLAContainerProcess.cpp b/src/windows/wslaservice/exe/WSLAContainerProcess.cpp index fd6cb4215..f7a085fa2 100644 --- a/src/windows/wslaservice/exe/WSLAContainerProcess.cpp +++ b/src/windows/wslaservice/exe/WSLAContainerProcess.cpp @@ -7,6 +7,17 @@ WSLAContainerProcess::WSLAContainerProcess(std::string&& Id, wil::unique_handle& m_id(std::move(Id)), m_ioStream(std::move(IoStream)), m_dockerClient(client), m_tty(Tty) { } +WSLAContainerProcess::~WSLAContainerProcess() +{ + + // TODO: consider moving this to a different class. + if (m_relayThread.has_value()) + { + m_exitRelayEvent.SetEvent(); + + m_relayThread->join(); + } +} HRESULT WSLAContainerProcess::Signal(_In_ int Signal) { @@ -15,7 +26,9 @@ HRESULT WSLAContainerProcess::Signal(_In_ int Signal) HRESULT WSLAContainerProcess::GetExitEvent(_Out_ ULONG* Event) { - return E_NOTIMPL; + *Event = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess(m_exitEvent.get())); + + return S_OK; } HRESULT WSLAContainerProcess::GetStdHandle(_In_ ULONG Index, _Out_ ULONG* Handle) @@ -23,17 +36,17 @@ try { std::lock_guard lock{m_mutex}; - auto& socket = GetStdHandle(Index); - RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !socket.is_valid()); + auto& handle = GetStdHandle(Index); + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !handle.is_valid()); - *Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess(socket.get())); + *Handle = HandleToUlong(common::wslutil::DuplicateHandleToCallingProcess(handle.get())); WSL_LOG( "GetStdHandle", TraceLoggingValue(Index, "fd"), - TraceLoggingValue(socket.get(), "handle"), + TraceLoggingValue(handle.get(), "handle"), TraceLoggingValue(*Handle, "remoteHandle")); - socket.reset(); + handle.reset(); return S_OK; return E_NOTIMPL; @@ -47,7 +60,18 @@ HRESULT WSLAContainerProcess::GetPid(_Out_ int* Pid) HRESULT WSLAContainerProcess::GetState(_Out_ WSLA_PROCESS_STATE* State, _Out_ int* Code) { - return E_NOTIMPL; + if (m_exitEvent.is_signaled()) + { + // TODO: Handle signals. + *State = WSLA_PROCESS_STATE::WslaProcessStateExited; + *Code = m_exitedCode; + } + else + { + *State = WSLA_PROCESS_STATE::WslaProcessStateRunning; + } + + return S_OK; } HRESULT WSLAContainerProcess::ResizeTty(_In_ ULONG Rows, _In_ ULONG Columns) @@ -66,10 +90,78 @@ wil::unique_handle& WSLAContainerProcess::GetStdHandle(int Index) { std::lock_guard lock{m_mutex}; + if (m_tty) + { + THROW_HR_IF_MSG(E_INVALIDARG, Index != 0, "Invalid fd index for tty process: %i", Index); + + return m_ioStream; + } + else + { + THROW_HR_IF_MSG(E_INVALIDARG, Index > m_relayedHandles->size(), "Invalid fd index for non-tty process: %i", Index); + + return m_relayedHandles->at(Index); + } + if (Index == 0 && m_tty) { return m_ioStream; } return m_ioStream; // TODO: fix +} + +void WSLAContainerProcess::RunIORelay(HANDLE exitEvent, wil::unique_handle&& stdinPipe, wil::unique_handle&& stdoutPipe, wil::unique_handle&& stderrPipe) +try +{ + common::relay::MultiHandleWait io; + + io.AddHandle(std::make_unique(exitEvent, [&]() { io.Cancel(); })); + io.AddHandle(std::make_unique(std::move(stdinPipe), m_ioStream.get())); + io.AddHandle(std::make_unique(m_ioStream.get(), std::move(stdoutPipe), std::move(stderrPipe))); + + io.Run({}); +} +CATCH_LOG(); + +void WSLAContainerProcess::StartIORelay() +{ + std::lock_guard lock{m_mutex}; + + WI_ASSERT(!m_relayThread.has_value()); + WI_ASSERT(!m_exitRelayEvent); + WI_ASSERT(!m_relayedHandles.has_value()); + + m_exitRelayEvent.create(wil::EventOptions::ManualReset); + m_relayedHandles.emplace(); + + auto createPipe = []() { + std::pair pipe; + THROW_IF_WIN32_BOOL_FALSE(CreatePipe(&pipe.first, &pipe.second, nullptr, 0)); + return pipe; + }; + + auto stdinPipe = createPipe(); + auto stdoutPipe = createPipe(); + auto stderrPipe = createPipe(); + + m_relayedHandles->emplace_back(std::move(stdinPipe.second)); + m_relayedHandles->emplace_back(std::move(stdoutPipe.first)); + m_relayedHandles->emplace_back(std::move(stdoutPipe.first)); + + m_relayThread.emplace([this, + event = m_exitRelayEvent.get(), + stdinPipe = std::move(stdinPipe.first), + stdoutPipe = std::move(stdoutPipe.second), + stderrPipe = std::move(stderrPipe.second)]() mutable { + RunIORelay(event, std::move(stdinPipe), std::move(stdoutPipe), std::move(stderrPipe)); + }); +} + +void WSLAContainerProcess::OnExited(int Code) +{ + WI_ASSERT(!m_exitEvent.is_signaled()); + + m_exitedCode = Code; + m_exitEvent.SetEvent(); } \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAContainerProcess.h b/src/windows/wslaservice/exe/WSLAContainerProcess.h index 389df6bd8..9b22ab745 100644 --- a/src/windows/wslaservice/exe/WSLAContainerProcess.h +++ b/src/windows/wslaservice/exe/WSLAContainerProcess.h @@ -10,6 +10,7 @@ class DECLSPEC_UUID("3A5DB29D-6D1D-4619-B89D-578EB34C8E52") WSLAContainerProcess { public: WSLAContainerProcess(std::string&& Id, wil::unique_handle&& IoStream, bool Tty, DockerHTTPClient& client); + ~WSLAContainerProcess(); IFACEMETHOD(Signal)(_In_ int Signal) override; IFACEMETHOD(GetExitEvent)(_Out_ ULONG* Event) override; @@ -18,13 +19,22 @@ class DECLSPEC_UUID("3A5DB29D-6D1D-4619-B89D-578EB34C8E52") WSLAContainerProcess IFACEMETHOD(GetState)(_Out_ WSLA_PROCESS_STATE* State, _Out_ int* Code) override; IFACEMETHOD(ResizeTty)(_In_ ULONG Rows, _In_ ULONG Columns) override; + void OnExited(int Code); + private: wil::unique_handle& GetStdHandle(int Index); + void StartIORelay(); + void RunIORelay(HANDLE exitEvent, wil::unique_handle&& stdinPipe, wil::unique_handle&& stdoutPipe, wil::unique_handle&& stderrPipe); wil::unique_handle m_ioStream; DockerHTTPClient& m_dockerClient; bool m_tty = false; + wil::unique_event m_exitEvent{wil::EventOptions::ManualReset}; + int m_exitedCode = -1; std::string m_id; + std::optional m_relayThread; + wil::unique_event m_exitRelayEvent; + std::optional> m_relayedHandles; std::recursive_mutex m_mutex; }; diff --git a/src/windows/wslaservice/exe/docker_schema.h b/src/windows/wslaservice/exe/docker_schema.h index 859ed055a..947defb43 100644 --- a/src/windows/wslaservice/exe/docker_schema.h +++ b/src/windows/wslaservice/exe/docker_schema.h @@ -40,4 +40,6 @@ struct CreateContainer NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(CreateContainer, Image, Cmd, Tty, OpenStdin, StdinOnce); }; + + } // namespace wsl::windows::service::wsla::docker_schema \ No newline at end of file From 982523656563df4865404a5fe080d2d0f30a0725 Mon Sep 17 00:00:00 2001 From: Blue Date: Mon, 29 Dec 2025 16:02:52 -0800 Subject: [PATCH 10/26] Implement non interactive IO relay --- nuget.config | 2 +- packages.config | 2 +- src/linux/init/WSLAInit.cpp | 7 +- src/windows/common/WslClient.cpp | 126 ++++++++++-------- src/windows/common/hvsocket.cpp | 3 + src/windows/common/relay.cpp | 10 ++ src/windows/common/relay.hpp | 22 ++- src/windows/wslaservice/exe/WSLAContainer.cpp | 17 +-- src/windows/wslaservice/exe/WSLAContainer.h | 3 +- .../wslaservice/exe/WSLAContainerProcess.cpp | 30 +++-- .../wslaservice/exe/WSLAContainerProcess.h | 2 +- src/windows/wslaservice/exe/docker_schema.h | 2 - src/windows/wslaservice/inc/wslaservice.idl | 1 - 13 files changed, 140 insertions(+), 87 deletions(-) diff --git a/nuget.config b/nuget.config index cf8d07af2..c0a0caabd 100644 --- a/nuget.config +++ b/nuget.config @@ -18,7 +18,7 @@ - +