Skip to content
14 changes: 11 additions & 3 deletions src/shared/inc/CommandLine.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,17 @@ class ArgumentParser
public:
#ifdef WIN32

ArgumentParser(const std::wstring& CommandLine, LPCWSTR Name, int StartIndex = 1) : m_startIndex(StartIndex), m_name(Name)
ArgumentParser(const std::wstring& CommandLine, LPCWSTR Name, int StartIndex = 1, bool stopUnknownArgs = false) :
m_startIndex(StartIndex), m_name(Name), m_stopUnknownArgs(stopUnknownArgs)
{
m_argv.reset(CommandLineToArgvW(std::wstring(CommandLine).c_str(), &m_argc));
THROW_LAST_ERROR_IF(!m_argv);
}

#else

ArgumentParser(int argc, const char* const* argv) : m_argc(argc), m_argv(argv), m_startIndex(1)
ArgumentParser(int argc, const char* const* argv, bool stopUnknownArgs = false) :
m_argc(argc), m_argv(argv), m_startIndex(1), m_stopUnknownArgs(stopUnknownArgs)
{
}

Expand Down Expand Up @@ -400,7 +402,7 @@ class ArgumentParser
const TChar* value = nullptr;
if (e.Positional)
{
value = m_argv[i]; // Positional arguments directly receive arvg[i]
value = m_argv[i]; // Positional arguments directly receive argv[i]
}
else if (i + 1 < m_argc)
{
Expand Down Expand Up @@ -429,6 +431,11 @@ class ArgumentParser

if (!foundMatch)
{
if (m_stopUnknownArgs)
{
break;
}

THROW_USER_ERROR(wsl::shared::Localization::MessageInvalidCommandLine(m_argv[i], m_name ? m_name : m_argv[0]));
}

Expand Down Expand Up @@ -535,6 +542,7 @@ class ArgumentParser

int m_startIndex{};
const TChar* m_name{};
bool m_stopUnknownArgs{false};
};
} // namespace wsl::shared

Expand Down
1 change: 1 addition & 0 deletions src/windows/common/ExecutionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ enum Context : ULONGLONG
UpdatePackage = 0x10000000000,
QueryLatestGitHubRelease = 0x20000000000,
VerifyChecksum = 0x40000000000,
WslaDiag = 0x80000000000,
};

DEFINE_ENUM_FLAG_OPERATORS(Context)
Expand Down
3 changes: 2 additions & 1 deletion src/windows/common/wslutil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ static const std::map<Context, LPCWSTR> g_contextStrings{
X(HNS),
X(ReadDistroConfig),
X(MoveDistro),
X(VerifyChecksum)};
X(VerifyChecksum),
X(WslaDiag)};

#undef X

Expand Down
142 changes: 82 additions & 60 deletions src/windows/wsladiag/wsladiag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module Name:

Abstract:

Entry point for the wsladiag tool, performs WSL runtime initialization and parses --list/--help.
Entry point for the wsladiag tool, performs WSL runtime initialization and parses list/shell/help.

--*/

Expand All @@ -18,11 +18,14 @@ Module Name:
#include "wslaservice.h"
#include "WslSecurity.h"
#include "WSLAProcessLauncher.h"
#include "ExecutionContext.h"
#include <thread>
#include <format>

using namespace wsl::shared;
namespace wslutil = wsl::windows::common::wslutil;
using wsl::windows::common::Context;
using wsl::windows::common::ExecutionContext;
using wsl::windows::common::WSLAProcessLauncher;

// Adding a helper to factor error handling between all the arguments.
Expand All @@ -42,9 +45,23 @@ static int ReportError(const std::wstring& context, HRESULT hr)
return 1;
}

// Handler for `wsladiag shell <SessionName>` (TTY-backed interactive shell).
static int RunShellCommand(const std::wstring& sessionName, bool verbose)
// Handler for `wsladiag shell <SessionName> [--verbose]` command - launches TTY-backed interactive shell.
static int RunShellCommand(std::wstring_view commandLine)
{
std::wstring sessionName;
bool verbose = false;

ArgumentParser parser(std::wstring{commandLine}, L"wsladiag", 2); // Skip "wsladiag.exe shell" to parse shell-specific args
parser.AddPositionalArgument(sessionName, 0);
parser.AddArgument(verbose, L"--verbose", L'v');

parser.Parse();

if (sessionName.empty())
{
THROW_HR(E_INVALIDARG);
}

const auto log = [&](std::wstring_view msg) {
if (verbose)
{
Expand Down Expand Up @@ -85,9 +102,11 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose)
wsl::windows::common::WSLAProcessLauncher launcher{
shell, {shell, "--login"}, {"TERM=xterm-256color"}, wsl::windows::common::ProcessFlags::None};

launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput});
launcher.AddFd(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput});
launcher.AddFd(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl});
launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput, .Path = nullptr});
launcher.AddFd(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput, .Path = nullptr});
launcher.AddFd(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl, .Path = nullptr});

log(std::format(L"[diag] tty rows={} cols={}", rows, cols));
launcher.SetTtySize(rows, cols);

log(L"[diag] launching shell process...");
Expand All @@ -96,7 +115,6 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose)

auto ttyIn = process.GetStdHandle(0);
auto ttyOut = process.GetStdHandle(1);
auto ttyControl = process.GetStdHandle(2);

// Console handles.
wil::unique_hfile conin{
Expand All @@ -119,11 +137,11 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose)
const UINT originalOutCP = GetConsoleOutputCP();
const UINT originalInCP = GetConsoleCP();

auto restoreConsole = wil::scope_exit([&] {
SetConsoleMode(consoleIn, originalInMode);
SetConsoleMode(consoleOut, originalOutMode);
SetConsoleOutputCP(originalOutCP);
SetConsoleCP(originalInCP);
auto restoreConsole = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] {
LOG_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleIn, originalInMode));
LOG_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleOut, originalOutMode));
LOG_IF_WIN32_BOOL_FALSE(SetConsoleOutputCP(originalOutCP));
LOG_IF_WIN32_BOOL_FALSE(SetConsoleCP(originalInCP));
});

// Console mode for interactive terminal.
Expand All @@ -139,8 +157,9 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose)
THROW_LAST_ERROR_IF(!SetConsoleOutputCP(CP_UTF8));
THROW_LAST_ERROR_IF(!SetConsoleCP(CP_UTF8));

// Keep terminal control socket alive.
auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset);

auto ttyControl = process.GetStdHandle(2); // TerminalControl
wsl::shared::SocketChannel controlChannel{wil::unique_socket{(SOCKET)ttyControl.release()}, "TerminalControl", exitEvent.get()};

auto updateTerminalSize = [&]() {
Expand All @@ -155,7 +174,8 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose)
controlChannel.SendMessage(message);
};

// Relay console -> tty input.
// Start input relay thread to forward console input to TTY
// Runs in parallel with output relay (main thread)
std::thread inputThread([&] {
try
{
Expand All @@ -179,15 +199,16 @@ static int RunShellCommand(const std::wstring& sessionName, bool verbose)
wsl::windows::common::relay::InterruptableRelay(ttyOut.get(), consoleOut, exitEvent.get());

process.GetExitEvent().wait();
auto [code, signalled] = process.GetExitState();

auto [exitCode, signalled] = process.GetExitState();

std::wstring shellWide(shell.begin(), shell.end());
wslutil::PrintMessage(std::format(L"{} exited with: {}{}", shellWide, code, signalled ? L" (signalled)" : L""), stdout);
wslutil::PrintMessage(std::format(L"{} exited with: {}{}", shellWide, exitCode, signalled ? L" (signalled)" : L""), stdout);

return 0;
}

static int RunListCommand(bool /*verbose*/)
static int RunListCommand()
{
wil::com_ptr<IWSLAUserSession> userSession;
THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession)));
Expand Down Expand Up @@ -240,6 +261,17 @@ static int RunListCommand(bool /*verbose*/)
return 0;
}

static void PrintUsage()
{
wslutil::PrintMessage(
L"wsladiag - WSLA diagnostics tool\n"
L"Usage:\n"
L" wsladiag list\n"
L" wsladiag shell <SessionName> [--verbose]\n"
L" wsladiag --help\n",
stderr);
}

int wsladiag_main(std::wstring_view commandLine)
{
wslutil::ConfigureCrt();
Expand All @@ -257,78 +289,68 @@ int wsladiag_main(std::wstring_view commandLine)
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &data));
auto wsaCleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, []() { WSACleanup(); });

ArgumentParser parser(std::wstring{commandLine}, L"wsladiag");
ArgumentParser parser(std::wstring{commandLine}, L"wsladiag", 1, true);

bool help = false;
bool verbose = false;
std::wstring verb;
std::wstring shellSession;

parser.AddPositionalArgument(verb, 0); // "list" or "shell"
parser.AddPositionalArgument(shellSession, 1); // session name for "shell"
parser.AddPositionalArgument(verb, 0);
parser.AddArgument(help, L"--help", L'h');
parser.AddArgument(verbose, L"--verbose", L'v');

auto printUsage = []() {
wslutil::PrintMessage(
L"wsladiag - WSLA diagnostics tool\n"
L"Usage:\n"
L" wsladiag list\n"
L" wsladiag shell <SessionName> [--verbose]\n"
L" wsladiag --help\n",
stderr);
};

try
{
parser.Parse();
}
catch (...)
{
const auto hr = wil::ResultFromCaughtException();
if (hr == E_INVALIDARG)
{
printUsage();
return 1;
}
throw;
}
parser.Parse(); // Let exceptions propagate to wmain for centralized handling

if (help || verb.empty())
{
printUsage();
PrintUsage();
return 0;
}
else if (verb == L"list")
{
return RunListCommand(verbose);
return RunListCommand();
}
else if (verb == L"shell")
{
if (shellSession.empty())
{
printUsage();
return 1;
}
return RunShellCommand(shellSession, verbose);
return RunShellCommand(commandLine);
}
else
{
wslutil::PrintMessage(std::format(L"Unknown command: '{}'", verb), stderr);
printUsage();
PrintUsage();
return 1;
}
}

int wmain(int, wchar_t**)
{
wsl::windows::common::EnableContextualizedErrors(false);

std::optional<ExecutionContext> context;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: context is always set to the same value, so it doesn't need to be an std::optional

int exitCode = 1;
HRESULT result = S_OK;

try
{
return wsladiag_main(GetCommandLineW());
context.emplace(Context::WslaDiag);
exitCode = wsladiag_main(GetCommandLineW());
}
catch (...)
{
const auto hr = wil::ResultFromCaughtException();
return ReportError(L"wsladiag failed", hr);
result = wil::ResultFromCaughtException();
}
}

if (FAILED(result))
{
if (context.has_value() && context->ReportedError().has_value())
{
auto strings = wsl::windows::common::wslutil::ErrorToString(context->ReportedError().value());
wslutil::PrintMessage(strings.Message, stderr);
}
else
{
// Fallback for errors without context
wslutil::PrintMessage(wslutil::GetErrorString(result), stderr);
}
}

return exitCode;
}
3 changes: 2 additions & 1 deletion test/windows/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ set(SOURCES
PluginTests.cpp
PolicyTests.cpp
InstallerTests.cpp
WSLATests.cpp)
WSLATests.cpp
WsladiagTests.cpp)

set(HEADERS
Common.h
Expand Down
2 changes: 1 addition & 1 deletion test/windows/WSLATests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class WSLATests
if (result.Code != expectedResult)
{
LogError(
"Comman didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'",
"Command didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'",
expectedResult,
result.Code,
result.Output[1].c_str(),
Expand Down
Loading