diff --git a/cmd/authd/main.go b/cmd/authd/main.go index 8314ec27e3..98619bc0c9 100644 --- a/cmd/authd/main.go +++ b/cmd/authd/main.go @@ -5,6 +5,7 @@ import ( "context" "os" "os/signal" + "strconv" "sync" "syscall" @@ -30,6 +31,7 @@ type app interface { } func run(a app) int { + os.Setenv("AUTHD_PID", strconv.FormatInt(int64(os.Getpid()), 10)) defer installSignalHandler(a)() if err := a.Run(); err != nil { diff --git a/nss/integration-tests/helper_test.go b/nss/integration-tests/helper_test.go index 010908d741..03b5db1c45 100644 --- a/nss/integration-tests/helper_test.go +++ b/nss/integration-tests/helper_test.go @@ -6,25 +6,22 @@ import ( "io" "os" "os/exec" - "path/filepath" + "slices" "testing" ) // getentOutputForLib returns the specific part for the nss command for the authd service. // It uses the locally build authd nss module for the integration tests. -func getentOutputForLib(t *testing.T, libPath, socketPath string, rustCovEnv []string, shouldPreCheck bool, cmds ...string) (got string, exitCode int) { +func getentOutputForLib(t *testing.T, socketPath string, env []string, shouldPreCheck bool, cmds ...string) (got string, exitCode int) { t.Helper() // #nosec:G204 - we control the command arguments in tests cmds = append(cmds, "--service", "authd") cmd := exec.Command("getent", cmds...) - cmd.Env = append(cmd.Env, - "AUTHD_NSS_INFO=stderr", - // NSS needs both LD_PRELOAD and LD_LIBRARY_PATH to load the module library - fmt.Sprintf("LD_PRELOAD=%s:%s", libPath, os.Getenv("LD_PRELOAD")), - fmt.Sprintf("LD_LIBRARY_PATH=%s:%s", filepath.Dir(libPath), os.Getenv("LD_LIBRARY_PATH")), - ) - cmd.Env = append(cmd.Env, rustCovEnv...) + cmd.Env = slices.Clone(env) + + // Set the PID to to self, so that we can verify that it won't work for all. + cmd.Env = append(cmd.Env, fmt.Sprintf("AUTHD_PID=%d", os.Getpid())) if socketPath != "" { cmd.Env = append(cmd.Env, fmt.Sprintf("AUTHD_NSS_SOCKET=%s", socketPath)) diff --git a/nss/integration-tests/integration_test.go b/nss/integration-tests/integration_test.go index 79caf306bc..b892db722b 100644 --- a/nss/integration-tests/integration_test.go +++ b/nss/integration-tests/integration_test.go @@ -1,9 +1,13 @@ package nss_test import ( + "bytes" "context" + "fmt" "log" "os" + "os/exec" + "os/user" "path/filepath" "strings" "testing" @@ -13,6 +17,7 @@ import ( "github.com/ubuntu/authd/internal/testutils" "github.com/ubuntu/authd/internal/testutils/golden" localgroupstestutils "github.com/ubuntu/authd/internal/users/localentries/testutils" + "gopkg.in/yaml.v3" ) var daemonPath string @@ -26,12 +31,21 @@ func TestIntegration(t *testing.T) { libPath, rustCovEnv := testutils.BuildRustNSSLib(t, false, "should_pre_check_env") // Create a default daemon to use for most test cases. - defaultSocket := filepath.Join(os.TempDir(), "nss-integration-tests.sock") + defaultSocket := filepath.Join(t.TempDir(), "nss.sock") defaultDbState := "multiple_users_and_groups" defaultOutputPath := filepath.Join(filepath.Dir(daemonPath), "gpasswd.output") defaultGroupsFilePath := filepath.Join(testutils.TestFamilyPath(t), "gpasswd.group") + nssLibraryEnv := append(rustCovEnv, + "AUTHD_NSS_INFO=stderr", + // NSS needs both LD_PRELOAD and LD_LIBRARY_PATH to load the module library + fmt.Sprintf("LD_PRELOAD=%s:%s", libPath, os.Getenv("LD_PRELOAD")), + fmt.Sprintf("LD_LIBRARY_PATH=%s:%s", filepath.Dir(libPath), os.Getenv("LD_LIBRARY_PATH")), + ) + env := append(localgroupstestutils.AuthdIntegrationTestsEnvWithGpasswdMock(t, defaultOutputPath, defaultGroupsFilePath), "AUTHD_INTEGRATIONTESTS_CURRENT_USER_AS_ROOT=1") + env = append(env, nssLibraryEnv...) + env = append(env, fmt.Sprintf("AUTHD_NSS_SOCKET=%s", defaultSocket)) ctx, cancel := context.WithCancel(context.Background()) _, stopped := testutils.RunDaemon(ctx, t, daemonPath, testutils.WithSocketPath(defaultSocket), @@ -118,12 +132,17 @@ func TestIntegration(t *testing.T) { outPath := filepath.Join(t.TempDir(), "gpasswd.output") groupsFilePath := filepath.Join("testdata", "empty.group") + socketPath = filepath.Join(t.TempDir(), "nss.sock") + var daemonStopped chan struct{} ctx, cancel := context.WithCancel(context.Background()) env := localgroupstestutils.AuthdIntegrationTestsEnvWithGpasswdMock(t, outPath, groupsFilePath) - socketPath, daemonStopped = testutils.RunDaemon(ctx, t, daemonPath, + env = append(env, nssLibraryEnv...) + env = append(env, fmt.Sprintf("AUTHD_NSS_SOCKET=%s", socketPath)) + _, daemonStopped = testutils.RunDaemon(ctx, t, daemonPath, testutils.WithPreviousDBState(tc.dbState), testutils.WithEnvironment(env...), + testutils.WithSocketPath(socketPath), ) t.Cleanup(func() { cancel() @@ -136,7 +155,7 @@ func TestIntegration(t *testing.T) { cmds = append(cmds, tc.key) } - got, status := getentOutputForLib(t, libPath, socketPath, rustCovEnv, tc.shouldPreCheck, cmds...) + got, status := getentOutputForLib(t, socketPath, nssLibraryEnv, tc.shouldPreCheck, cmds...) require.Equal(t, tc.wantStatus, status, "Expected status %d, but got %d", tc.wantStatus, status) if tc.shouldPreCheck && tc.getentDB == "passwd" { @@ -164,12 +183,118 @@ func TestIntegration(t *testing.T) { // This is to check that some cache tasks, such as cleaning a corrupted database, work as expected. if tc.wantSecondCall { - got, status := getentOutputForLib(t, libPath, socketPath, rustCovEnv, tc.shouldPreCheck, cmds...) + got, status := getentOutputForLib(t, socketPath, nssLibraryEnv, tc.shouldPreCheck, cmds...) require.NotEqual(t, codeNotFound, status, "Expected no error, but got %v", status) require.Empty(t, got, "Expected empty output, but got %q", got) } }) } + + runPidAbuser := func(action, arg string) []byte { + require.NotEmpty(t, action, "Setup: action should not be empty") + + // #nosec:G204 - we control the command arguments in tests + cmd := exec.Command("go", "run") + if testutils.CoverDirForTests() != "" { + // -cover is a "positional flag", so it needs to come right after the "build" command. + cmd.Args = append(cmd.Args, "-cover") + cmd.Env = testutils.AppendCovEnv(env) + } + if testutils.IsRace() { + cmd.Args = append(cmd.Args, "-race") + } + cmd.Env = append(cmd.Env, nssLibraryEnv...) + cmd.Env = append(cmd.Env, + fmt.Sprintf("AUTHD_NSS_SOCKET=%s", defaultSocket), + "ACTION="+action, + "ACTION_ARG="+arg, + ) + cmd.Env = append(cmd.Env, os.Environ()...) + + cmd.Dir = "pid_abuser" + cmd.Args = append(cmd.Args, "./") + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + require.NoError(t, err, "Could not run PID abuser: %s, %s", + stdout.String(), stderr.String()) + t.Logf("STDOUT:\n%s", stdout.String()) + t.Logf("STDERR:\n%s", stderr.String()) + return stdout.Bytes() + } + + t.Run("Simulate_running_as_authd", func(t *testing.T) { + tests := map[string]struct { + action string + arg string + + want any + }{ + "Lookups_user": { + action: "lookup_user", + arg: "user1", + want: user.User{ + Uid: "1111", + Gid: "11111", + Username: "user1", + Name: "User1 gecos\nOn multiple lines", + HomeDir: "/home/user1", + }, + }, + "Lookups_group": { + action: "lookup_group", + arg: "group1", + want: user.Group{Gid: "11111", Name: "group1"}, + }, + "Lookups_uid": { + action: "lookup_uid", + arg: "1111", + want: user.User{ + Uid: "1111", + Gid: "11111", + Username: "user1", + Name: "User1 gecos\nOn multiple lines", + HomeDir: "/home/user1", + }, + }, + "Lookups_gid": { + action: "lookup_gid", + arg: "11111", + want: user.Group{Gid: "11111", Name: "group1"}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + ret := runPidAbuser(tc.action, tc.arg) + + switch action, _ := strings.CutPrefix(tc.action, "lookup_"); action { + case "user": + fallthrough + case "uid": + u := unmarshalYAML[user.User](t, ret) + require.Equal(t, tc.want, u, "User does not match") + case "group": + fallthrough + case "gid": + g := unmarshalYAML[user.Group](t, ret) + require.Equal(t, tc.want, g, "Group does not match") + } + }) + } + }) +} + +func unmarshalYAML[T any](t *testing.T, yml []byte) T { + t.Helper() + + var val T + err := yaml.Unmarshal(yml, &val) + require.NoError(t, err, "Unmarshalling failed:\n%q", yml) + return val } func TestMockgpasswd(t *testing.T) { diff --git a/nss/integration-tests/pid_abuser/pidabuser.go b/nss/integration-tests/pid_abuser/pidabuser.go new file mode 100644 index 0000000000..732169ee62 --- /dev/null +++ b/nss/integration-tests/pid_abuser/pidabuser.go @@ -0,0 +1,49 @@ +// TiCS: disabled // This file is a test helper. + +// Package main is the package for the pid abuser test tool. +package main + +import ( + "fmt" + "os" + "os/user" + "strconv" + + "gopkg.in/yaml.v3" +) + +func main() { + os.Setenv("AUTHD_PID", strconv.FormatInt(int64(os.Getpid()), 10)) + + action := os.Getenv("ACTION") + actionArg := os.Getenv("ACTION_ARG") + + switch action { + case "lookup_user": + outputAsYAMLOrFail(user.Lookup(actionArg)) + + case "lookup_group": + outputAsYAMLOrFail(user.LookupGroup(actionArg)) + + case "lookup_uid": + outputAsYAMLOrFail(user.LookupId(actionArg)) + + case "lookup_gid": + outputAsYAMLOrFail(user.LookupGroupId(actionArg)) + + default: + panic("Invalid action " + action) + } +} + +func outputAsYAMLOrFail[T any](val T, err error) { + if err != nil { + panic(err) + } + + out, err := yaml.Marshal(val) + if err != nil { + panic(err) + } + fmt.Printf("%s\n", out) +} diff --git a/nss/src/client/mod.rs b/nss/src/client/mod.rs index a65515ee36..755dd37746 100644 --- a/nss/src/client/mod.rs +++ b/nss/src/client/mod.rs @@ -1,6 +1,7 @@ use authd::user_service_client::UserServiceClient; use hyper_util::rt::TokioIo; use std::error::Error; +use std::sync::OnceLock; use tokio::net::UnixStream; use tonic::transport::{Channel, Endpoint, Uri}; use tower::service_fn; @@ -11,18 +12,79 @@ pub mod authd { tonic::include_proto!("authd"); } +const AUTHD_PID_ENV_VAR: &str = "AUTHD_PID"; + /// new_client creates a new client connection to the gRPC server or returns an active one. pub async fn new_client() -> Result, Box> { info!("Connecting to authd on {}...", super::socket_path()); + // Cache for self-check result. + static AUTHD_PROCESS_CHECK: OnceLock = OnceLock::new(); + + let connector = service_fn(|_: Uri| async { + let stream = UnixStream::connect(super::socket_path()).await?; + + if *AUTHD_PROCESS_CHECK.get_or_init(|| check_is_authd_process(&stream)) { + info!("Module loaded by authd itself: ignoring the connection"); + + return Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Ignoring connection from authd to authd itself", + )); + } + + Ok::<_, std::io::Error>(TokioIo::new(stream)) + }); + // The URL must have a valid format, even though we don't use it. let ch = Endpoint::try_from("https://not-used:404")? .connect_timeout(CONNECTION_TIMEOUT) - .connect_with_connector(service_fn(|_: Uri| async { - let stream = UnixStream::connect(super::socket_path()).await?; - Ok::<_, std::io::Error>(TokioIo::new(stream)) - })) + .connect_with_connector(connector) .await?; Ok(UserServiceClient::new(ch)) } + +fn check_is_authd_process(stream: &UnixStream) -> bool { + // Check if we've been launched with a AUTHD_PID env variable set with + // a numeric value. If these checks fail, we can just continue with the + // connection as we were. As for sure the library has not been loaded + // by authd. + let Ok(authd_pid) = std::env::var(AUTHD_PID_ENV_VAR) else { + return false; + }; + info!( + "authd module launched with {}={}", + AUTHD_PID_ENV_VAR, authd_pid + ); + let Ok(authd_pid_value) = authd_pid.parse::() else { + return false; + }; + + let current_pid = std::process::id(); + info!("current PID is {}", current_pid); + if current_pid != authd_pid_value { + return false; + } + + // Get the peer credentials, and check if the server PIDs matches the + // AUTHD_PID, an if it does, we can avoid any connection since we're + // sure that we have been loaded by authd (and not by another crafted + // client to act like it, to ignore the authd module) + let Ok(peer_cred) = stream.peer_cred() else { + return false; + }; + let Some(peer_pid) = peer_cred.pid() else { + return false; + }; + + info!( + "authd socket is provided by PID {} (expecting {})", + peer_pid, authd_pid + ); + if authd_pid_value != peer_pid.try_into().unwrap() { + return false; + } + + return true; +} diff --git a/nss/src/passwd/mod.rs b/nss/src/passwd/mod.rs index 2bab9ec006..97beed58a5 100644 --- a/nss/src/passwd/mod.rs +++ b/nss/src/passwd/mod.rs @@ -3,6 +3,7 @@ use libc::uid_t; use libnss::interop::Response; use libnss::passwd::{Passwd, PasswdHooks}; use std::path::PathBuf; +use std::sync::OnceLock; use tokio::runtime::Builder; use tonic::Request; @@ -178,13 +179,18 @@ fn is_proc_matching(pid: u32, name: &str) -> bool { /// should_pre_check returns true if the current process sshd or a child of sshd. #[allow(unreachable_code)] // This function body is overridden in integration tests, so we need to ignore the warning. fn should_pre_check() -> bool { - #[cfg(feature = "should_pre_check_env")] - return std::env::var("AUTHD_NSS_SHOULD_PRE_CHECK").is_ok(); + static SHOULD_PRE_CHECK: OnceLock = OnceLock::new(); - let pid = std::process::id(); - if is_proc_matching(pid, SSHD_BINARY_PATH) { - return true; - } + *SHOULD_PRE_CHECK.get_or_init(|| { + #[cfg(feature = "should_pre_check_env")] + return std::env::var("AUTHD_NSS_SHOULD_PRE_CHECK").is_ok(); + + let pid = std::process::id(); - is_proc_matching(std::os::unix::process::parent_id(), SSHD_BINARY_PATH) + if is_proc_matching(pid, SSHD_BINARY_PATH) { + return true; + } + + is_proc_matching(std::os::unix::process::parent_id(), SSHD_BINARY_PATH) + }) }