diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 3bdfd341d..24e7ab7d9 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -263,13 +263,11 @@ - - @@ -301,7 +299,6 @@ - diff --git a/src/shared/configfile/configfile.h b/src/shared/configfile/configfile.h index 6b4393a76..4749515d1 100644 --- a/src/shared/configfile/configfile.h +++ b/src/shared/configfile/configfile.h @@ -20,7 +20,7 @@ Parses .gitconfig-style properties files. #ifdef WIN32 #include -#include "wslservice.h" +#include "wslaservice.h" #include "ExecutionContext.h" #endif diff --git a/src/windows/common/svccommio.hpp b/src/windows/common/svccommio.hpp index 95de9b263..244921d3c 100644 --- a/src/windows/common/svccommio.hpp +++ b/src/windows/common/svccommio.hpp @@ -16,7 +16,7 @@ Module Name: #include #include -#include "wslservice.h" +#include "wslaservice.h" typedef struct _LXSS_STD_HANDLES_INFO { diff --git a/src/windows/wslaclient/DllMain.cpp b/src/windows/wslaclient/DllMain.cpp index ac193f79d..730b540b8 100644 --- a/src/windows/wslaclient/DllMain.cpp +++ b/src/windows/wslaclient/DllMain.cpp @@ -58,7 +58,7 @@ try THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&session))); wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); - wil::com_ptr virtualMachineInstance; + /* wil::com_ptr virtualMachineInstance; VIRTUAL_MACHINE_SETTINGS settings{}; settings.DisplayName = UserSettings->DisplayName; @@ -86,7 +86,7 @@ try // Callback instance is now owned by the service. } - *reinterpret_cast(VirtualMachine) = virtualMachineInstance.detach(); + *reinterpret_cast(VirtualMachine) = virtualMachineInstance.detach();*/ return S_OK; } CATCH_RETURN(); @@ -426,4 +426,4 @@ try { return reinterpret_cast(VirtualMachine)->MountGpuLibraries(LibrariesMountPoint, DriversMountpoint, static_cast(Flags)); } -CATCH_RETURN(); \ No newline at end of file +CATCH_RETURN(); diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt index 46260d14d..7a701b95f 100644 --- a/src/windows/wslaservice/exe/CMakeLists.txt +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -33,4 +33,4 @@ target_link_libraries(wslaservice Synchronization.lib) target_precompile_headers(wslaservice REUSE_FROM common) -set_target_properties(wslaservice PROPERTIES FOLDER windows) \ No newline at end of file +set_target_properties(wslaservice PROPERTIES FOLDER windows) diff --git a/src/windows/wslaservice/exe/ContainerManager.cpp b/src/windows/wslaservice/exe/ContainerManager.cpp new file mode 100644 index 000000000..9c75a2df8 --- /dev/null +++ b/src/windows/wslaservice/exe/ContainerManager.cpp @@ -0,0 +1,266 @@ +#include "C:/Users/trivedipooja/source/repos/WSL/src/windows/common/CMakeFiles/common.dir/Debug/cmake_pch.hxx" +#include "ContainerManager.h" +#include +#include +#include + +namespace wsl::windows::service::wsla { + +ContainerManager::ContainerManager(WSLAVirtualMachine* pVM) : m_pVM(pVM) +{ +} + +ContainerManager::~ContainerManager() +{ +} + +ContainerResult ContainerManager::StartNewContainer(const ContainerOptions& options) +{ + // TODO: Fix! + std::string containerNameStr = "something";//Utils::GenerateUuidString(); + + ContainerResult containerResult; + int containerId = m_containerId++; + + { + std::lock_guard containersLock{m_containersLock}; + if (CheckPortConflicts(options.PortMappings)) + { + THROW_WIN32_MSG(ERROR_ADDRESS_ALREADY_ASSOCIATED, "Port is already in use."); + } + + m_containers.emplace(containerId, ContainerInfo{containerNameStr, ContainerState::Creating}); + } + + // Building nerdctl run command with args + std::string command = prepareNerdctlRunCommand(options.Image, containerNameStr, options); + + try + { + auto runningProcess = StartProcess(command, options.InitProcessOptions); + + containerResult.Result = S_OK; + containerResult.ContainerId = containerId; + containerResult.MainProcess.StdIn = runningProcess.StdIn; + containerResult.MainProcess.StdOut = runningProcess.StdOut; + containerResult.MainProcess.StdErr = runningProcess.StdErr; + containerResult.MainProcess.Pid = runningProcess.Pid; + + if (SUCCEEDED(containerResult.Result)) + { + std::lock_guard containersLock{m_containersLock}; + m_containers[containerId].State = ContainerState::Created; + + // Adding port mapping + for (const auto& portMapping : options.PortMappings) + { + THROW_IF_FAILED(m_pVM->MapPort(portMapping.AddressFamily, portMapping.WindowsPort, portMapping.LinuxPort, false)); + m_containers[containerId].PortMappings.push_back(portMapping); + } + } + return containerResult; + } + catch (const wil::ResultException& re) + { + // TODO: LOG_CAUGHT_EXCEPTION() when logging is enabled + + std::ignore = StopContainer(containerId, true); + + containerResult.Result = re.GetErrorCode(); + return containerResult; + } +} + +ContainerResult ContainerManager::StartContainer(const int containerId) +{ + ContainerResult containerResult; + containerResult.Result = E_NOTIMPL; // TODO: Implement + return containerResult; +} + +HRESULT ContainerManager::StopContainer(const int containerId, bool remove) +{ + return S_OK; + /* + std::lock_guard containersLock{m_containersLock}; + auto it = m_containers.find(containerId); + THROW_HR_IF_MSG(E_INVALIDARG, it == m_containers.end(), "Container with id %d not found.", containerId); + if (it->second.State != ContainerState::Started) + { + RETURN_HR(S_OK); + } + // Building nerdctl stop command with args + std::vector args{"stop", "--time=10", it->second.Name}; + auto runningProcess = StartProcess(args); + auto processResult = WaitForProcess(runningProcess, 20000); + RETURN_HR_IF_MSG(processResult.Result, FAILED(processResult.Result), "Failed to stop container %d", containerId); + if (remove) + { + // Building nerdctl rm command with args + std::vector rmArgs{"rm", it->second.Name}; + auto rmProcess = StartProcess(rmArgs); + auto rmProcessResult = WaitForProcess(rmProcess, 20000); + RETURN_HR_IF_MSG(rmProcessResult.Result, FAILED(rmProcessResult.Result), "Failed to remove container %d", containerId); + } + // Unmapping ports + for (const auto& portMapping : it->second.PortMappings) + { + THROW_IF_FAILED(m_pVM->MapPort(portMapping.AddressFamily, portMapping.WindowsPort, portMapping.LinuxPort, true)); + } + m_containers.erase(it); + RETURN_HR(S_OK); + */ +} + +ContainerResult ContainerManager::RestartContainer(const int containerId) +{ + ContainerResult containerResult; + containerResult.Result = E_NOTIMPL; // TODO: Implement + return containerResult; +} + +////// private + +bool ContainerManager::CheckPortConflicts(const std::vector& portMappings) +{ + for (const auto& newMapping : portMappings) + { + for (const auto& containerEntry : m_containers) + { + for (const auto& existingMapping : containerEntry.second.PortMappings) + { + if (newMapping.AddressFamily == existingMapping.AddressFamily && + (newMapping.LinuxPort == existingMapping.LinuxPort || newMapping.WindowsPort == existingMapping.WindowsPort)) + { + return true; + } + } + } + } + return false; +} + +bool ContainerManager::IsContainerRunning(const int containerId) +{ + return false; + /* + std::lock_guard containersLock{m_containersLock}; + auto it = m_containers.find(containerId); + THROW_HR_IF_MSG(E_INVALIDARG, it == m_containers.end(), "Container with id %d not found.", containerId); + return it->second.State == ContainerState::Started; + */ +} + +std::string ContainerManager::prepareNerdctlRunCommand(std::string_view image, std::string_view containerName, const ContainerOptions& options) +{ + ContainerManager::NerdctlCommandBuilder builder; + + builder.addArgument("run").addArgument(ContainerManager::defaultNerdctlRunArgs).addArgument("--name").addArgument(containerName); + if (options.ShmSizeMb > 0) + { + builder.addArgument("--shm-size=" + std::to_string(options.ShmSizeMb) + 'm'); + } + if (options.GPUOptions.Enable) + { + builder.addArgument({"--gpus", options.GPUOptions.GPUDevices}); + } + + // TODO: Add envs! + /* for (const auto& env : options.MainProcessOptions.Envs) + { + args.insert(args.end(), {"-e", env}); + } */ + + // Adding local mount paths + for (const auto& volume : options.Volumes) + { + THROW_WIN32_IF_MSG( + ERROR_NOT_SUPPORTED, volume.MountPoint.find(':') != std::string::npos, "Char ':' not supported for MountPoint."); + + std::string mountContainerPath; + mountContainerPath = std::string(volume.HostPath) + ":" + std::string(volume.MountPoint); + if (volume.IsReadOnly) + { + mountContainerPath += ":ro"; + } + + builder.addArgument({"-v", mountContainerPath}); + } + + // TODO: .addEnv("SOMETHING1", "SOEMTHING2") // Add a custom environment variable + builder.addArgument(image); + + // Add main process args + for (const auto& processArgs : options.InitProcessOptions.CommandLine) + { + builder.addArgument(processArgs); + } + + return builder.build(); +} + +ContainerProcess ContainerManager::StartProcess(const std::string& command, const ContainerProcessOptions& processOptions) +{ + WSLA_CREATE_PROCESS_OPTIONS options{}; + + auto Count = [](const auto* Ptr) -> ULONG { + if (Ptr == nullptr) + { + return 0; + } + + ULONG Result = 0; + + while (*Ptr != nullptr) + { + Result++; + Ptr++; + } + + return Result; + }; + + std::vector clArray; + for (const auto& cl : processOptions.CommandLine) + { + clArray.push_back(cl.c_str()); + } + clArray.push_back(nullptr); + + std::vector envArray; + for (const auto& env : processOptions.CommandLine) + { + envArray.push_back(env.c_str()); + } + clArray.push_back(nullptr); + options.Executable = processOptions.Executable.c_str(); + options.CommandLine = clArray.data(); + options.CommandLineCount = Count(options.CommandLine); + options.Environment = envArray.data(); + options.EnvironmentCount = Count(options.Environment); + options.CurrentDirectory = processOptions.CurrentDirectory.c_str(); + + std::vector inputFd(3); + inputFd[0].Fd = 0; + inputFd[0].Type = WslFdType::WslFdTypeDefault; + inputFd[1].Fd = 1; + inputFd[1].Type = WslFdType::WslFdTypeDefault; + inputFd[2].Fd = 2; + inputFd[2].Type = WslFdType::WslFdTypeDefault; + + std::vector env{"PATH=/sbin:/usr/sbin:/bin:/usr/bin", nullptr}; + + std::vector fds(3); + if (fds.empty()) + { + fds.resize(1); // COM doesn't like null pointers. + } + + WSLA_CREATE_PROCESS_RESULT result{}; + THROW_IF_FAILED(reinterpret_cast(m_pVM) + ->CreateLinuxProcess(&options, 3, inputFd.data(), fds.data(), &result)); + + return { UlongToHandle(fds[0]), UlongToHandle(fds[1]), UlongToHandle(fds[2]), result.Pid }; +} +} // namespace wsl::windows::service::wsla + diff --git a/src/windows/wslaservice/exe/ContainerManager.h b/src/windows/wslaservice/exe/ContainerManager.h new file mode 100644 index 000000000..4565ecbb2 --- /dev/null +++ b/src/windows/wslaservice/exe/ContainerManager.h @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "WSLAVirtualMachine.h" + +namespace wsl::windows::service::wsla { + + struct ContainerVolume + { + bool IsReadOnly; + std::string HostPath; + std::string MountPoint; + }; + + struct PortMapping + { + uint16_t WindowsPort; + uint16_t LinuxPort; + int AddressFamily; + }; + + struct GPUOptions + { + bool Enable; + std::string GPUDevices; + }; + + enum ContainerProcessFlags + { + None = 0, + InteractiveShell, + }; + + struct ContainerProcessOptions + { + std::string Executable; + std::vector CommandLine; + std::vector Environment; + std::string CurrentDirectory; + HANDLE TerminalControlChannel; // Used to create interactive shells, (to handle terminal window resizes) + uint32_t Rows; // Only applicable when creating an interactive shell + uint32_t Columns; + uint32_t Flags = None | InteractiveShell; + }; + + // For use when starting or waiting on a new nerdctl process + struct ContainerProcess + { + HANDLE StdIn; + HANDLE StdOut; + HANDLE StdErr; + int Pid = -1; + }; + + struct ContainerOptions + { + std::string Image; + std::string Name; + ContainerProcessOptions InitProcessOptions; + std::vector Volumes; + std::vector PortMappings; + GPUOptions GPUOptions; + uint64_t ShmSizeMb; + }; + + struct ContainerResult + { + HRESULT Result; + int32_t ContainerId = -1; + ContainerProcess MainProcess; + }; + + enum ContainerState + { + Default, + Creating, + Created, + Stopping, + Exited, + Failed, + + // TODO: Future consideration, add more container states and keep + // states in sync with actual runtime, like exited, etc. + }; + + // For use in the container map m_containers + struct ContainerInfo + { + std::string Name; + ContainerState State = ContainerState::Default; + std::vector PortMappings; + }; + + struct ContainerProcessResult + { + HRESULT Result; + int ExitCode = -1; + std::string StdOut; + std::string StdErr; + }; + + class ContainerManager + { + public: + ContainerManager(WSLAVirtualMachine* vm); + ContainerManager(const ContainerManager&) = delete; + ContainerManager& operator=(const ContainerManager&) = delete; + ContainerManager(ContainerManager&&) = delete; + ContainerManager& operator=(ContainerManager&&) = delete; + + ~ContainerManager(); + + ContainerResult StartNewContainer(const ContainerOptions& options); + ContainerResult StartContainer(const int containerId); + HRESULT StopContainer(const int containerId, bool remove); + ContainerResult RestartContainer(const int containerId); + + // Constants for required default arguments for "nerdctl run..." + static constexpr std::initializer_list defaultNerdctlRunArgs = { + "--pull=never", + "--host=net", // TODO: default for now, change later + "--ulimit nofile=65536:65536"}; + + + // Nested helper class for building nerdctl commands + class NerdctlCommandBuilder + { + public: + private: + const std::string baseCommand = "/usr/bin/nerdctl"; + std::vector args; + // TODO: std::vector envVariables; // Store owned strings for ENV vars + std::vector defaultGlobalArgs; // Any nerdctl-wide global args we want added to every nerdctl command + + public: + NerdctlCommandBuilder() + { + // Initialize with default arguments + if (defaultGlobalArgs.size()) + { + args.insert(args.end(), defaultGlobalArgs.begin(), defaultGlobalArgs.end()); + } + } + + NerdctlCommandBuilder& addArgument(std::string_view arg) + { + args.push_back(arg); + return *this; + } + + NerdctlCommandBuilder& addArgument(std::initializer_list arguments) + { + // Efficiently insert all elements from the initializer list into the arguments vector + args.insert(args.end(), arguments); + return *this; + } + + /* TODO : Adds environment variables(e.g., -e KEY = VALUE) + NerdctlCommandBuilder& addEnv(const std::string& key, const std::string& value) + { + // Formats as "-e KEY=VALUE" + std::string envArg = "-e " + key + "=" + value; + envVariables.push_back(std::move(envArg)); + return *this; + } */ + + // Finalizes and constructs the full command string + std::string build() + { + std::stringstream ss; + ss << baseCommand; + + // 1. Add Default and Custom Arguments + for (const auto& arg : args) + { + ss << " " << arg; + } + + /* TODO: + 2. Add Environment Variables(which are dynamically created strings) + for (const auto& env : envVariables) + { + ss << " " << env; + } */ + + return ss.str(); + } + }; + + private: + WSLAVirtualMachine* m_pVM; + std::map m_containers; + std::recursive_mutex m_containersLock; + std::atomic_int m_containerId = 1; + + bool IsContainerRunning(const int containerId); + + // Start a new nerdctl process + ContainerProcess StartProcess(const std::string& command, const ContainerProcessOptions& processOptions); + ContainerProcessResult WaitForProcess(const ContainerProcess& process, uint64_t waitTimeOutMs = 60000); + + bool CheckPortConflicts(const std::vector& portMappings); + + std::string prepareNerdctlRunCommand(std::string_view image, std::string_view containerName, const ContainerOptions& options); + }; +} // namespace wsl::windows::service::wsla diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 71939ce85..5428c3895 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -21,6 +21,7 @@ namespace wsl::windows::service::wsla { class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession : public Microsoft::WRL::RuntimeClass, IWSLASession, IFastRundown> + { public: WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings); @@ -34,4 +35,4 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession std::wstring m_displayName; }; -} // namespace wsl::windows::service::wsla \ No newline at end of file +} // namespace wsl::windows::service::wsla diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index d1c840c75..384dfb50b 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -113,3 +113,6 @@ try return session->CreateSession(Settings, VmSettings, WslaSession); } CATCH_RETURN(); + + + diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index bce7bad78..6b13d263e 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -1107,4 +1107,4 @@ try RETURN_IF_FAILED(Mount("none", LibrariesMountPoint, "overlay", options.c_str(), Flags)); return S_OK; } -CATCH_RETURN(); \ No newline at end of file +CATCH_RETURN(); diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 003ea3d1b..164bbaa89 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -101,7 +101,6 @@ struct _VIRTUAL_MACHINE_SETTINGS { BOOL EnableGPU; } VIRTUAL_MACHINE_SETTINGS; - typedef struct _WSLA_SESSION_SETTINGS { LPCWSTR DisplayName; @@ -130,4 +129,4 @@ interface IWSLAUserSession : IUnknown HRESULT GetVersion([out] WSL_VERSION* Error); HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine); HRESULT CreateSession([in] const WSLA_SESSION_SETTINGS* Settings, [in] const VIRTUAL_MACHINE_SETTINGS* VmSettings, [out] IWSLASession** Session); -} \ No newline at end of file +} diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 6a37f432e..a9c69daa5 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -1113,7 +1113,7 @@ class WSLATests VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); - WSLA_SESSION_SETTINGS settings{L"my-display-name"}; + WSLA_SESSION_SETTINGS settings{L"create-session-smoke-test"}; wil::com_ptr session; VIRTUAL_MACHINE_SETTINGS vmSettings{}; @@ -1131,4 +1131,4 @@ class WSLATests VERIFY_ARE_EQUAL(returnedDisplayName.get(), std::wstring(L"my-display-name")); } -}; \ No newline at end of file +};