diff --git a/src/shared/inc/CommandLine.h b/src/shared/inc/CommandLine.h index 9966c5e65..3ef367ed4 100644 --- a/src/shared/inc/CommandLine.h +++ b/src/shared/inc/CommandLine.h @@ -307,7 +307,8 @@ class ArgumentParser public: #ifdef WIN32 - ArgumentParser(const std::wstring& CommandLine, LPCWSTR Name, int StartIndex = 1) : m_startIndex(StartIndex), m_name(Name) + ArgumentParser(const std::wstring& CommandLine, LPCWSTR Name, int StartIndex = 1, bool stopUnknownArgs = false) : + m_startIndex(StartIndex), m_name(Name), m_stopUnknownArgs(stopUnknownArgs) { m_argv.reset(CommandLineToArgvW(std::wstring(CommandLine).c_str(), &m_argc)); THROW_LAST_ERROR_IF(!m_argv); @@ -315,7 +316,8 @@ class ArgumentParser #else - ArgumentParser(int argc, const char* const* argv) : m_argc(argc), m_argv(argv), m_startIndex(1) + ArgumentParser(int argc, const char* const* argv, bool stopUnknownArgs = false) : + m_argc(argc), m_argv(argv), m_startIndex(1), m_stopUnknownArgs(stopUnknownArgs) { } @@ -400,7 +402,7 @@ class ArgumentParser const TChar* value = nullptr; if (e.Positional) { - value = m_argv[i]; // Positional arguments directly receive arvg[i] + value = m_argv[i]; // Positional arguments directly receive argv[i] } else if (i + 1 < m_argc) { @@ -429,6 +431,11 @@ class ArgumentParser if (!foundMatch) { + if (m_stopUnknownArgs) + { + break; + } + THROW_USER_ERROR(wsl::shared::Localization::MessageInvalidCommandLine(m_argv[i], m_name ? m_name : m_argv[0])); } @@ -535,6 +542,7 @@ class ArgumentParser int m_startIndex{}; const TChar* m_name{}; + bool m_stopUnknownArgs{false}; }; } // namespace wsl::shared diff --git a/src/windows/common/ExecutionContext.h b/src/windows/common/ExecutionContext.h index 0fe473d9f..9c12ff644 100644 --- a/src/windows/common/ExecutionContext.h +++ b/src/windows/common/ExecutionContext.h @@ -66,6 +66,7 @@ enum Context : ULONGLONG UpdatePackage = 0x10000000000, QueryLatestGitHubRelease = 0x20000000000, VerifyChecksum = 0x40000000000, + WslaDiag = 0x80000000000, }; DEFINE_ENUM_FLAG_OPERATORS(Context) diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index 8c4ad8ac3..544c6fce5 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -190,7 +190,8 @@ static const std::map g_contextStrings{ X(HNS), X(ReadDistroConfig), X(MoveDistro), - X(VerifyChecksum)}; + X(VerifyChecksum), + X(WslaDiag)}; #undef X diff --git a/src/windows/wsladiag/wsladiag.cpp b/src/windows/wsladiag/wsladiag.cpp index eef2f3150..0cec3a59c 100644 --- a/src/windows/wsladiag/wsladiag.cpp +++ b/src/windows/wsladiag/wsladiag.cpp @@ -8,7 +8,7 @@ Module Name: Abstract: - Entry point for the wsladiag tool, performs WSL runtime initialization and parses --list/--help. + Entry point for the wsladiag tool, performs WSL runtime initialization and parses list/shell/help. --*/ @@ -18,11 +18,14 @@ Module Name: #include "wslaservice.h" #include "WslSecurity.h" #include "WSLAProcessLauncher.h" +#include "ExecutionContext.h" #include #include using namespace wsl::shared; namespace wslutil = wsl::windows::common::wslutil; +using wsl::windows::common::Context; +using wsl::windows::common::ExecutionContext; using wsl::windows::common::WSLAProcessLauncher; // Adding a helper to factor error handling between all the arguments. @@ -42,9 +45,23 @@ static int ReportError(const std::wstring& context, HRESULT hr) return 1; } -// Handler for `wsladiag shell ` (TTY-backed interactive shell). -static int RunShellCommand(const std::wstring& sessionName, bool verbose) +// Handler for `wsladiag shell [--verbose]` command - launches TTY-backed interactive shell. +static int RunShellCommand(std::wstring_view commandLine) { + std::wstring sessionName; + bool verbose = false; + + ArgumentParser parser(std::wstring{commandLine}, L"wsladiag", 2); // Skip "wsladiag.exe shell" to parse shell-specific args + parser.AddPositionalArgument(sessionName, 0); + parser.AddArgument(verbose, L"--verbose", L'v'); + + parser.Parse(); + + if (sessionName.empty()) + { + THROW_HR(E_INVALIDARG); + } + const auto log = [&](std::wstring_view msg) { if (verbose) { @@ -85,9 +102,11 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose) wsl::windows::common::WSLAProcessLauncher launcher{ shell, {shell, "--login"}, {"TERM=xterm-256color"}, wsl::windows::common::ProcessFlags::None}; - launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}); - launcher.AddFd(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}); - launcher.AddFd(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl}); + launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput, .Path = nullptr}); + launcher.AddFd(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput, .Path = nullptr}); + launcher.AddFd(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl, .Path = nullptr}); + + log(std::format(L"[diag] tty rows={} cols={}", rows, cols)); launcher.SetTtySize(rows, cols); log(L"[diag] launching shell process..."); @@ -96,7 +115,6 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose) auto ttyIn = process.GetStdHandle(0); auto ttyOut = process.GetStdHandle(1); - auto ttyControl = process.GetStdHandle(2); // Console handles. wil::unique_hfile conin{ @@ -119,11 +137,11 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose) const UINT originalOutCP = GetConsoleOutputCP(); const UINT originalInCP = GetConsoleCP(); - auto restoreConsole = wil::scope_exit([&] { - SetConsoleMode(consoleIn, originalInMode); - SetConsoleMode(consoleOut, originalOutMode); - SetConsoleOutputCP(originalOutCP); - SetConsoleCP(originalInCP); + auto restoreConsole = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + LOG_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleIn, originalInMode)); + LOG_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleOut, originalOutMode)); + LOG_IF_WIN32_BOOL_FALSE(SetConsoleOutputCP(originalOutCP)); + LOG_IF_WIN32_BOOL_FALSE(SetConsoleCP(originalInCP)); }); // Console mode for interactive terminal. @@ -139,8 +157,9 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose) THROW_LAST_ERROR_IF(!SetConsoleOutputCP(CP_UTF8)); THROW_LAST_ERROR_IF(!SetConsoleCP(CP_UTF8)); - // Keep terminal control socket alive. auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); + + auto ttyControl = process.GetStdHandle(2); // TerminalControl wsl::shared::SocketChannel controlChannel{wil::unique_socket{(SOCKET)ttyControl.release()}, "TerminalControl", exitEvent.get()}; auto updateTerminalSize = [&]() { @@ -155,7 +174,8 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose) controlChannel.SendMessage(message); }; - // Relay console -> tty input. + // Start input relay thread to forward console input to TTY + // Runs in parallel with output relay (main thread) std::thread inputThread([&] { try { @@ -179,15 +199,16 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose) wsl::windows::common::relay::InterruptableRelay(ttyOut.get(), consoleOut, exitEvent.get()); process.GetExitEvent().wait(); - auto [code, signalled] = process.GetExitState(); + + auto [exitCode, signalled] = process.GetExitState(); std::wstring shellWide(shell.begin(), shell.end()); - wslutil::PrintMessage(std::format(L"{} exited with: {}{}", shellWide, code, signalled ? L" (signalled)" : L""), stdout); + wslutil::PrintMessage(std::format(L"{} exited with: {}{}", shellWide, exitCode, signalled ? L" (signalled)" : L""), stdout); return 0; } -static int RunListCommand(bool /*verbose*/) +static int RunListCommand() { wil::com_ptr userSession; THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); @@ -240,6 +261,17 @@ static int RunListCommand(bool /*verbose*/) return 0; } +static void PrintUsage() +{ + wslutil::PrintMessage( + L"wsladiag - WSLA diagnostics tool\n" + L"Usage:\n" + L" wsladiag list\n" + L" wsladiag shell [--verbose]\n" + L" wsladiag --help\n", + stderr); +} + int wsladiag_main(std::wstring_view commandLine) { wslutil::ConfigureCrt(); @@ -257,78 +289,67 @@ int wsladiag_main(std::wstring_view commandLine) THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &data)); auto wsaCleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, []() { WSACleanup(); }); - ArgumentParser parser(std::wstring{commandLine}, L"wsladiag"); + ArgumentParser parser(std::wstring{commandLine}, L"wsladiag", 1, true); bool help = false; - bool verbose = false; std::wstring verb; - std::wstring shellSession; - parser.AddPositionalArgument(verb, 0); // "list" or "shell" - parser.AddPositionalArgument(shellSession, 1); // session name for "shell" + parser.AddPositionalArgument(verb, 0); parser.AddArgument(help, L"--help", L'h'); - parser.AddArgument(verbose, L"--verbose", L'v'); - auto printUsage = []() { - wslutil::PrintMessage( - L"wsladiag - WSLA diagnostics tool\n" - L"Usage:\n" - L" wsladiag list\n" - L" wsladiag shell [--verbose]\n" - L" wsladiag --help\n", - stderr); - }; - - try - { - parser.Parse(); - } - catch (...) - { - const auto hr = wil::ResultFromCaughtException(); - if (hr == E_INVALIDARG) - { - printUsage(); - return 1; - } - throw; - } + parser.Parse(); // Let exceptions propagate to wmain for centralized handling if (help || verb.empty()) { - printUsage(); + PrintUsage(); return 0; } else if (verb == L"list") { - return RunListCommand(verbose); + return RunListCommand(); } else if (verb == L"shell") { - if (shellSession.empty()) - { - printUsage(); - return 1; - } - return RunShellCommand(shellSession, verbose); + return RunShellCommand(commandLine); } else { wslutil::PrintMessage(std::format(L"Unknown command: '{}'", verb), stderr); - printUsage(); + PrintUsage(); return 1; } } int wmain(int, wchar_t**) { + wsl::windows::common::EnableContextualizedErrors(false); + + ExecutionContext context{Context::WslaDiag}; + int exitCode = 1; + HRESULT result = S_OK; + try { - return wsladiag_main(GetCommandLineW()); + exitCode = wsladiag_main(GetCommandLineW()); } catch (...) { - const auto hr = wil::ResultFromCaughtException(); - return ReportError(L"wsladiag failed", hr); + result = wil::ResultFromCaughtException(); } -} + + if (FAILED(result)) + { + if (auto reported = context.ReportedError()) + { + auto strings = wsl::windows::common::wslutil::ErrorToString(*reported); + wslutil::PrintMessage(strings.Message.empty() ? strings.Code : strings.Message, stderr); + } + else + { + // Fallback for errors without context + wslutil::PrintMessage(wslutil::GetErrorString(result), stderr); + } + } + + return exitCode; +} \ No newline at end of file diff --git a/test/windows/CMakeLists.txt b/test/windows/CMakeLists.txt index c700d66f4..1abe54e68 100644 --- a/test/windows/CMakeLists.txt +++ b/test/windows/CMakeLists.txt @@ -9,7 +9,8 @@ set(SOURCES PluginTests.cpp PolicyTests.cpp InstallerTests.cpp - WSLATests.cpp) + WSLATests.cpp + WsladiagTests.cpp) set(HEADERS Common.h diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 38dfbc395..37591cb0e 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -170,7 +170,7 @@ class WSLATests if (result.Code != expectedResult) { LogError( - "Comman didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'", + "Command didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'", expectedResult, result.Code, result.Output[1].c_str(), diff --git a/test/windows/WsladiagTests.cpp b/test/windows/WsladiagTests.cpp new file mode 100644 index 000000000..50bed9fc9 --- /dev/null +++ b/test/windows/WsladiagTests.cpp @@ -0,0 +1,148 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WsladiagTests.cpp + +Abstract: + + This file contains smoke tests for wsladiag. + +--*/ + +#include "precomp.h" +#include "Common.h" +#include + +static const std::wstring c_usageText = + L"wsladiag - WSLA diagnostics tool\r\n" + L"Usage:\r\n" + L" wsladiag list\r\n" + L" wsladiag shell [--verbose]\r\n" + L" wsladiag --help\r\n"; + +namespace WsladiagTests { +class WsladiagTests +{ + WSL_TEST_CLASS(WsladiagTests) + + // Test that wsladiag list command shows either sessions or "no sessions" message + TEST_METHOD(List_ShowsSessionsOrNoSessions) + { + auto [out, err, code] = RunWsladiag(L"list"); + VERIFY_ARE_EQUAL(0, code); + VERIFY_ARE_EQUAL(L"", err); + + ValidateListOutput(out); + } + + // Test that wsladiag --help shows usage information + TEST_METHOD(Help_ShowsUsage) + { + ValidateWsladiagOutput(L"--help", 0, L"", c_usageText); + } + + // Test that wsladiag with no arguments shows usage information + TEST_METHOD(EmptyCommand_ShowsUsage) + { + ValidateWsladiagOutput(L"", 0, L"", c_usageText); + } + + // Test that -h and --help flags produce identical output + TEST_METHOD(Help_ShortAndLongFlags_Match) + { + auto [outH, errH, codeH] = RunWsladiag(L"-h"); + auto [outLong, errLong, codeLong] = RunWsladiag(L"--help"); + + VERIFY_ARE_EQUAL(0, codeH); + VERIFY_ARE_EQUAL(0, codeLong); + + VERIFY_ARE_EQUAL(L"", outH); + VERIFY_ARE_EQUAL(L"", outLong); + + VERIFY_ARE_EQUAL(errH, errLong); + ValidateUsage(errH); + } + + // Test that unknown commands show error message and usage + TEST_METHOD(UnknownCommand_ShowsError) + { + ValidateWsladiagOutput(L"blah", 1, L"", std::wstring(L"Unknown command: 'blah'\r\n") + c_usageText); + } + + // Test that shell command without session name shows error + TEST_METHOD(Shell_MissingName_ShowsError) + { + ValidateWsladiagOutput(L"shell", 1, L"", L"The parameter is incorrect.\r\n"); + } + + // Test shell command with invalid session name (silent mode) + TEST_METHOD(Shell_InvalidSessionName_Silent) + { + auto [out, err, code] = RunWsladiag(L"shell DefinitelyNotARealSession"); + VERIFY_ARE_NOT_EQUAL(0, code); + + VERIFY_ARE_EQUAL(L"", out); + VERIFY_IS_TRUE(err.find(L"Session not found: 'DefinitelyNotARealSession'") != std::wstring::npos); + } + + // Test shell command with invalid session name (verbose mode) + TEST_METHOD(Shell_InvalidSessionName_Verbose) + { + const std::wstring name = L"DefinitelyNotARealSession"; + auto [out, err, code] = RunWsladiag(std::format(L"shell {} --verbose", name)); + VERIFY_ARE_NOT_EQUAL(0, code); + + VERIFY_IS_TRUE(out.find(std::format(L"[diag] shell='{}'", name)) != std::wstring::npos); + VERIFY_IS_TRUE(err.find(L"Session not found") != std::wstring::npos); + } + + // Build command line for wsladiag.exe with given arguments + static std::wstring BuildWsladiagCmd(const std::wstring& args) + { + const auto msiPathOpt = wsl::windows::common::wslutil::GetMsiPackagePath(); + VERIFY_IS_TRUE(msiPathOpt.has_value()); + + const auto exePath = std::filesystem::path(*msiPathOpt) / L"wsladiag.exe"; + const auto exe = exePath.wstring(); + + return args.empty() ? std::format(L"\"{}\"", exe) : std::format(L"\"{}\" {}", exe, args); + } + + // Execute wsladiag with given arguments and return output, error, and exit code + static std::tuple RunWsladiag(const std::wstring& args) + { + auto cmd = BuildWsladiagCmd(args); + return LxsstuLaunchCommandAndCaptureOutputWithResult(cmd.data()); + } + + static void ValidateWsladiagOutput(const std::wstring& cmd, int expectedExitCode, const std::wstring& expectedStdout, const std::wstring& expectedStderr) + { + auto [out, err, code] = RunWsladiag(cmd); + VERIFY_ARE_EQUAL(expectedExitCode, code); + VERIFY_ARE_EQUAL(expectedStdout, out); + VERIFY_ARE_EQUAL(expectedStderr, err); + } + + // Validate that list command output shows either no sessions message or session table + static void ValidateListOutput(const std::wstring& out) + { + const bool noSessions = out.find(L"No WSLA sessions found.") != std::wstring::npos; + + const bool hasTable = out.find(L"Found") != std::wstring::npos && out.find(L"ID") != std::wstring::npos && + out.find(L"Creator PID") != std::wstring::npos && out.find(L"Display Name") != std::wstring::npos; + + VERIFY_IS_TRUE(noSessions || hasTable); + } + + // Validate that usage information contains expected command descriptions + static void ValidateUsage(const std::wstring& err) + { + VERIFY_IS_TRUE(err.find(L"Usage:") != std::wstring::npos); + VERIFY_IS_TRUE(err.find(L"wsladiag list") != std::wstring::npos); + VERIFY_IS_TRUE(err.find(L"wsladiag shell [--verbose]") != std::wstring::npos); + } +}; +} // namespace WsladiagTests \ No newline at end of file