Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions auth_requestor_plymouth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,21 @@ import (
// PlymouthAuthRequestorStringer is used by the Plymouth implementation
// of [AuthRequestor] to obtain translated strings.
type PlymouthAuthRequestorStringer interface {
// RequestUserCredentialFormatString returns a format string used by
// RequestUserCredential to construct a message that is used to request
// credentials with the supplied auth types. The returned format string
// is interpreted with the following parameters:
// - %[1]s: A human readable name for the storage container.
// - %[2]s: The path of the encrypted storage container.
RequestUserCredentialFormatString(authTypes UserAuthType) (string, error)
// RequestUserCredentialString returns messages used by RequestUserCredential. The
// name is a string supplied via the WithAuthRequestorUserVisibleName option, and the
// path is the storage container path.
RequestUserCredentialString(name, path string, authTypes UserAuthType) (string, error)
}

type plymouthAuthRequestor struct {
stringer PlymouthAuthRequestorStringer
}

func (r *plymouthAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, UserAuthType, error) {
fmtString, err := r.stringer.RequestUserCredentialFormatString(authTypes)
msg, err := r.stringer.RequestUserCredentialString(name, path, authTypes)
if err != nil {
return "", 0, fmt.Errorf("cannot request format string for requested auth types: %w", err)
return "", 0, fmt.Errorf("cannot request message string: %w", err)
}
msg := fmt.Sprintf(fmtString, name, path)

cmd := exec.CommandContext(
ctx, "plymouth", "ask-for-password",
Expand Down
22 changes: 12 additions & 10 deletions auth_requestor_plymouth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,31 @@ type mockPlymouthAuthRequestorStringer struct {
rucErr error
}

func (s *mockPlymouthAuthRequestorStringer) RequestUserCredentialFormatString(authType UserAuthType) (string, error) {
func (s *mockPlymouthAuthRequestorStringer) RequestUserCredentialString(name, path string, authType UserAuthType) (string, error) {
if s.rucErr != nil {
return "", s.rucErr
}

var fmtString string
switch authType {
case UserAuthTypePassphrase:
return "Enter passphrase for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase for %s (%s):"
case UserAuthTypePIN:
return "Enter PIN for %[1]s (%[2]s):", nil
fmtString = "Enter PIN for %s (%s):"
case UserAuthTypeRecoveryKey:
return "Enter recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter recovery key for %s (%s):"
case UserAuthTypePassphrase | UserAuthTypePIN:
return "Enter passphrase or PIN for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase or PIN for %s (%s):"
case UserAuthTypePassphrase | UserAuthTypeRecoveryKey:
return "Enter passphrase or recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase or recovery key for %s (%s):"
case UserAuthTypePIN | UserAuthTypeRecoveryKey:
return "Enter PIN or recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter PIN or recovery key for %s (%s):"
case UserAuthTypePassphrase | UserAuthTypePIN | UserAuthTypeRecoveryKey:
return "Enter passphrase, PIN or recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase, PIN or recovery key for %s (%s):"
default:
return "", errors.New("unexpected UserAuthType")
}
return fmt.Sprintf(fmtString, name, path), nil
}

type testPlymouthRequestUserCredentialsParams struct {
Expand Down Expand Up @@ -226,14 +228,14 @@ func (s *authRequestorPlymouthSuite) TestNewRequestorNoStringer(c *C) {
c.Check(err, ErrorMatches, `must supply an implementation of PlymouthAuthRequestorStringer`)
}

func (s *authRequestorPlymouthSuite) TestRequestUserCredentialObtainFormatStringError(c *C) {
func (s *authRequestorPlymouthSuite) TestRequestUserCredentialObtainMessageError(c *C) {
requestor, err := NewPlymouthAuthRequestor(&mockPlymouthAuthRequestorStringer{
rucErr: errors.New("some error"),
})
c.Assert(err, IsNil)

_, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase)
c.Check(err, ErrorMatches, `cannot request format string for requested auth types: some error`)
c.Check(err, ErrorMatches, `cannot request message string: some error`)
}

func (s *authRequestorPlymouthSuite) TestRequestUserCredentialFailure(c *C) {
Expand Down
26 changes: 14 additions & 12 deletions auth_requestor_systemd.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,21 @@ import (
"strings"
)

// SystemdAuthRequestorStringFn is a callback used to supply translated messages
// to the systemd implementation of AuthRequestor.RequestUserCredential. The name
// is a string supplied via the [WithAuthRequestorUserVisibleName] option, and the
// path is the storage container path.
type SystemdAuthRequestorStringFn func(name, path string, authTypes UserAuthType) (string, error)

type systemdAuthRequestor struct {
formatStringFn func(UserAuthType) (string, error)
stringFn SystemdAuthRequestorStringFn
}

func (r *systemdAuthRequestor) RequestUserCredential(ctx context.Context, name, path string, authTypes UserAuthType) (string, UserAuthType, error) {
fmtString, err := r.formatStringFn(authTypes)
msg, err := r.stringFn(name, path, authTypes)
if err != nil {
return "", 0, fmt.Errorf("cannot request format string for requested auth types: %w", err)
return "", 0, fmt.Errorf("cannot request message string: %w", err)
}
msg := fmt.Sprintf(fmtString, name, path)

cmd := exec.CommandContext(
ctx, "systemd-ask-password",
Expand All @@ -62,15 +67,12 @@ func (r *systemdAuthRequestor) RequestUserCredential(ctx context.Context, name,

// NewSystemdAuthRequestor creates an implementation of AuthRequestor that
// delegates to the systemd-ask-password binary. The caller supplies a callback
// to map user auth type combinations to format strings that are used to
// messages.The format strings are interpreted with the following parameters:
// - %[1]s: A human readable name for the storage container.
// - %[2]s: The path of the encrypted storage container.
func NewSystemdAuthRequestor(formatStringFn func(UserAuthType) (string, error)) (AuthRequestor, error) {
if formatStringFn == nil {
return nil, errors.New("must supply a callback to obtain format strings for requesting user credentials")
// to supply messages for user auth requests.
func NewSystemdAuthRequestor(stringFn SystemdAuthRequestorStringFn) (AuthRequestor, error) {
if stringFn == nil {
return nil, errors.New("must supply a SystemdAuthRequestorStringFn")
}
return &systemdAuthRequestor{
formatStringFn: formatStringFn,
stringFn: stringFn,
}, nil
}
32 changes: 17 additions & 15 deletions auth_requestor_systemd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,27 @@ type testSystemdRequestUserCredentialsParams struct {
func (s *authRequestorSystemdSuite) testRequestUserCredential(c *C, params *testSystemdRequestUserCredentialsParams) {
s.setPassphrase(c, params.passphrase)

requestor, err := NewSystemdAuthRequestor(func(authType UserAuthType) (string, error) {
requestor, err := NewSystemdAuthRequestor(func(name, path string, authType UserAuthType) (string, error) {
var fmtString string
switch authType {
case UserAuthTypePassphrase:
return "Enter passphrase for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase for %s (%s):"
case UserAuthTypePIN:
return "Enter PIN for %[1]s (%[2]s):", nil
fmtString = "Enter PIN for %s (%s):"
case UserAuthTypeRecoveryKey:
return "Enter recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter recovery key for %s (%s):"
case UserAuthTypePassphrase | UserAuthTypePIN:
return "Enter passphrase or PIN for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase or PIN for %s (%s):"
case UserAuthTypePassphrase | UserAuthTypeRecoveryKey:
return "Enter passphrase or recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase or recovery key for %s (%s):"
case UserAuthTypePIN | UserAuthTypeRecoveryKey:
return "Enter PIN or recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter PIN or recovery key for %s (%s):"
case UserAuthTypePassphrase | UserAuthTypePIN | UserAuthTypeRecoveryKey:
return "Enter passphrase, PIN or recovery key for %[1]s (%[2]s):", nil
fmtString = "Enter passphrase, PIN or recovery key for %s (%s):"
default:
return "", errors.New("unexpected UserAuthType")
}
return fmt.Sprintf(fmtString, name, path), nil
})
c.Assert(err, IsNil)

Expand Down Expand Up @@ -215,23 +217,23 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialPassphraseOrPINOrRe

func (s *authRequestorSystemdSuite) TestNewRequestorNoFormatStringCallback(c *C) {
_, err := NewSystemdAuthRequestor(nil)
c.Check(err, ErrorMatches, `must supply a callback to obtain format strings for requesting user credentials`)
c.Check(err, ErrorMatches, `must supply a SystemdAuthRequestorStringFn`)
}

func (s *authRequestorSystemdSuite) TestRequestUserCredentialObtainFormatStringError(c *C) {
requestor, err := NewSystemdAuthRequestor(func(UserAuthType) (string, error) {
func (s *authRequestorSystemdSuite) TestRequestUserCredentialObtainMessageError(c *C) {
requestor, err := NewSystemdAuthRequestor(func(string, string, UserAuthType) (string, error) {
return "", errors.New("some error")
})
c.Assert(err, IsNil)

_, _, err = requestor.RequestUserCredential(context.Background(), "data", "/dev/sda1", UserAuthTypePassphrase)
c.Check(err, ErrorMatches, `cannot request format string for requested auth types: some error`)
c.Check(err, ErrorMatches, `cannot request message string: some error`)
}

func (s *authRequestorSystemdSuite) TestRequestUserCredentialInvalidResponse(c *C) {
c.Assert(ioutil.WriteFile(s.passwordFile, []byte("foo"), 0600), IsNil)

requestor, err := NewSystemdAuthRequestor(func(UserAuthType) (string, error) {
requestor, err := NewSystemdAuthRequestor(func(string, string, UserAuthType) (string, error) {
return "", nil
})
c.Assert(err, IsNil)
Expand All @@ -241,7 +243,7 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialInvalidResponse(c *
}

func (s *authRequestorSystemdSuite) TestRequestUserCredentialFailure(c *C) {
requestor, err := NewSystemdAuthRequestor(func(UserAuthType) (string, error) {
requestor, err := NewSystemdAuthRequestor(func(string, string, UserAuthType) (string, error) {
return "", nil
})
c.Assert(err, IsNil)
Expand All @@ -253,7 +255,7 @@ func (s *authRequestorSystemdSuite) TestRequestUserCredentialFailure(c *C) {
func (s *authRequestorSystemdSuite) TestRequestUserCredentialCanceledContext(c *C) {
c.Assert(ioutil.WriteFile(s.passwordFile, []byte("foo"), 0600), IsNil)

requestor, err := NewSystemdAuthRequestor(func(UserAuthType) (string, error) {
requestor, err := NewSystemdAuthRequestor(func(string, string, UserAuthType) (string, error) {
return "", nil
})
c.Assert(err, IsNil)
Expand Down
Loading