From 93315acd039cb0cf84312dca5112804e090242da Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Mon, 15 Dec 2025 17:55:25 +0100 Subject: [PATCH 01/14] Add gRPC method SetShell --- internal/proto/authd/authd.pb.go | 179 ++++++++++++------ internal/proto/authd/authd.proto | 6 + internal/proto/authd/authd_grpc.pb.go | 38 ++++ internal/services/permissions/permissions.go | 46 ++++- .../testdata/golden/TestRegisterGRPCServices | 3 + internal/services/user/user.go | 25 +++ internal/users/db/update.go | 17 ++ internal/users/manager.go | 27 +++ internal/users/userutils.go | 68 +++++++ 9 files changed, 342 insertions(+), 67 deletions(-) create mode 100644 internal/users/userutils.go diff --git a/internal/proto/authd/authd.pb.go b/internal/proto/authd/authd.pb.go index 556a17b4ac..b0a71ba675 100644 --- a/internal/proto/authd/authd.pb.go +++ b/internal/proto/authd/authd.pb.go @@ -1413,6 +1413,58 @@ func (x *SetGroupIDResponse) GetWarnings() []string { return nil } +type SetShellRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Shell string `protobuf:"bytes,2,opt,name=shell,proto3" json:"shell,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetShellRequest) Reset() { + *x = SetShellRequest{} + mi := &file_authd_proto_msgTypes[26] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetShellRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetShellRequest) ProtoMessage() {} + +func (x *SetShellRequest) ProtoReflect() protoreflect.Message { + mi := &file_authd_proto_msgTypes[26] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetShellRequest.ProtoReflect.Descriptor instead. +func (*SetShellRequest) Descriptor() ([]byte, []int) { + return file_authd_proto_rawDescGZIP(), []int{26} +} + +func (x *SetShellRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *SetShellRequest) GetShell() string { + if x != nil { + return x.Shell + } + return "" +} + type User struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` @@ -1427,7 +1479,7 @@ type User struct { func (x *User) Reset() { *x = User{} - mi := &file_authd_proto_msgTypes[26] + mi := &file_authd_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1439,7 +1491,7 @@ func (x *User) String() string { func (*User) ProtoMessage() {} func (x *User) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[26] + mi := &file_authd_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1452,7 +1504,7 @@ func (x *User) ProtoReflect() protoreflect.Message { // Deprecated: Use User.ProtoReflect.Descriptor instead. func (*User) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{26} + return file_authd_proto_rawDescGZIP(), []int{27} } func (x *User) GetName() string { @@ -1506,7 +1558,7 @@ type Users struct { func (x *Users) Reset() { *x = Users{} - mi := &file_authd_proto_msgTypes[27] + mi := &file_authd_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1518,7 +1570,7 @@ func (x *Users) String() string { func (*Users) ProtoMessage() {} func (x *Users) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[27] + mi := &file_authd_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1531,7 +1583,7 @@ func (x *Users) ProtoReflect() protoreflect.Message { // Deprecated: Use Users.ProtoReflect.Descriptor instead. func (*Users) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{27} + return file_authd_proto_rawDescGZIP(), []int{28} } func (x *Users) GetUsers() []*User { @@ -1554,7 +1606,7 @@ type Group struct { func (x *Group) Reset() { *x = Group{} - mi := &file_authd_proto_msgTypes[28] + mi := &file_authd_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1566,7 +1618,7 @@ func (x *Group) String() string { func (*Group) ProtoMessage() {} func (x *Group) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[28] + mi := &file_authd_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1579,7 +1631,7 @@ func (x *Group) ProtoReflect() protoreflect.Message { // Deprecated: Use Group.ProtoReflect.Descriptor instead. func (*Group) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{28} + return file_authd_proto_rawDescGZIP(), []int{29} } func (x *Group) GetName() string { @@ -1619,7 +1671,7 @@ type Groups struct { func (x *Groups) Reset() { *x = Groups{} - mi := &file_authd_proto_msgTypes[29] + mi := &file_authd_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1631,7 +1683,7 @@ func (x *Groups) String() string { func (*Groups) ProtoMessage() {} func (x *Groups) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[29] + mi := &file_authd_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1644,7 +1696,7 @@ func (x *Groups) ProtoReflect() protoreflect.Message { // Deprecated: Use Groups.ProtoReflect.Descriptor instead. func (*Groups) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{29} + return file_authd_proto_rawDescGZIP(), []int{30} } func (x *Groups) GetGroups() []*Group { @@ -1665,7 +1717,7 @@ type ABResponse_BrokerInfo struct { func (x *ABResponse_BrokerInfo) Reset() { *x = ABResponse_BrokerInfo{} - mi := &file_authd_proto_msgTypes[30] + mi := &file_authd_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1677,7 +1729,7 @@ func (x *ABResponse_BrokerInfo) String() string { func (*ABResponse_BrokerInfo) ProtoMessage() {} func (x *ABResponse_BrokerInfo) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[30] + mi := &file_authd_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1724,7 +1776,7 @@ type GAMResponse_AuthenticationMode struct { func (x *GAMResponse_AuthenticationMode) Reset() { *x = GAMResponse_AuthenticationMode{} - mi := &file_authd_proto_msgTypes[31] + mi := &file_authd_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1736,7 +1788,7 @@ func (x *GAMResponse_AuthenticationMode) String() string { func (*GAMResponse_AuthenticationMode) ProtoMessage() {} func (x *GAMResponse_AuthenticationMode) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[31] + mi := &file_authd_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1781,7 +1833,7 @@ type IARequest_AuthenticationData struct { func (x *IARequest_AuthenticationData) Reset() { *x = IARequest_AuthenticationData{} - mi := &file_authd_proto_msgTypes[32] + mi := &file_authd_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1793,7 +1845,7 @@ func (x *IARequest_AuthenticationData) String() string { func (*IARequest_AuthenticationData) ProtoMessage() {} func (x *IARequest_AuthenticationData) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[32] + mi := &file_authd_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1998,7 +2050,10 @@ const file_authd_proto_rawDesc = "" + "\n" + "id_changed\x18\x01 \x01(\bR\tidChanged\x123\n" + "\x16home_dir_owner_changed\x18\x02 \x01(\bR\x13homeDirOwnerChanged\x12\x1a\n" + - "\bwarnings\x18\x03 \x03(\tR\bwarnings\"\x84\x01\n" + + "\bwarnings\x18\x03 \x03(\tR\bwarnings\";\n" + + "\x0fSetShellRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" + + "\x05shell\x18\x02 \x01(\tR\x05shell\"\x84\x01\n" + "\x04User\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n" + "\x03uid\x18\x02 \x01(\rR\x03uid\x12\x10\n" + @@ -2028,7 +2083,7 @@ const file_authd_proto_rawDesc = "" + "\x0fIsAuthenticated\x12\x10.authd.IARequest\x1a\x11.authd.IAResponse\x12,\n" + "\n" + "EndSession\x12\x10.authd.ESRequest\x1a\f.authd.Empty\x12<\n" + - "\x17SetDefaultBrokerForUser\x12\x13.authd.SDBFURequest\x1a\f.authd.Empty2\xb6\x04\n" + + "\x17SetDefaultBrokerForUser\x12\x13.authd.SDBFURequest\x1a\f.authd.Empty2\xe8\x04\n" + "\vUserService\x129\n" + "\rGetUserByName\x12\x1b.authd.GetUserByNameRequest\x1a\v.authd.User\x125\n" + "\vGetUserByID\x12\x19.authd.GetUserByIDRequest\x1a\v.authd.User\x12'\n" + @@ -2038,7 +2093,8 @@ const file_authd_proto_rawDesc = "" + "UnlockUser\x12\x18.authd.UnlockUserRequest\x1a\f.authd.Empty\x12>\n" + "\tSetUserID\x12\x17.authd.SetUserIDRequest\x1a\x18.authd.SetUserIDResponse\x12A\n" + "\n" + - "SetGroupID\x12\x18.authd.SetGroupIDRequest\x1a\x19.authd.SetGroupIDResponse\x12<\n" + + "SetGroupID\x12\x18.authd.SetGroupIDRequest\x1a\x19.authd.SetGroupIDResponse\x120\n" + + "\bSetShell\x12\x16.authd.SetShellRequest\x1a\f.authd.Empty\x12<\n" + "\x0eGetGroupByName\x12\x1c.authd.GetGroupByNameRequest\x1a\f.authd.Group\x128\n" + "\fGetGroupByID\x12\x1a.authd.GetGroupByIDRequest\x1a\f.authd.Group\x12)\n" + "\n" + @@ -2057,7 +2113,7 @@ func file_authd_proto_rawDescGZIP() []byte { } var file_authd_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_authd_proto_msgTypes = make([]protoimpl.MessageInfo, 33) +var file_authd_proto_msgTypes = make([]protoimpl.MessageInfo, 34) var file_authd_proto_goTypes = []any{ (SessionMode)(0), // 0: authd.SessionMode (*Empty)(nil), // 1: authd.Empty @@ -2086,23 +2142,24 @@ var file_authd_proto_goTypes = []any{ (*SetUserIDResponse)(nil), // 24: authd.SetUserIDResponse (*SetGroupIDRequest)(nil), // 25: authd.SetGroupIDRequest (*SetGroupIDResponse)(nil), // 26: authd.SetGroupIDResponse - (*User)(nil), // 27: authd.User - (*Users)(nil), // 28: authd.Users - (*Group)(nil), // 29: authd.Group - (*Groups)(nil), // 30: authd.Groups - (*ABResponse_BrokerInfo)(nil), // 31: authd.ABResponse.BrokerInfo - (*GAMResponse_AuthenticationMode)(nil), // 32: authd.GAMResponse.AuthenticationMode - (*IARequest_AuthenticationData)(nil), // 33: authd.IARequest.AuthenticationData + (*SetShellRequest)(nil), // 27: authd.SetShellRequest + (*User)(nil), // 28: authd.User + (*Users)(nil), // 29: authd.Users + (*Group)(nil), // 30: authd.Group + (*Groups)(nil), // 31: authd.Groups + (*ABResponse_BrokerInfo)(nil), // 32: authd.ABResponse.BrokerInfo + (*GAMResponse_AuthenticationMode)(nil), // 33: authd.GAMResponse.AuthenticationMode + (*IARequest_AuthenticationData)(nil), // 34: authd.IARequest.AuthenticationData } var file_authd_proto_depIdxs = []int32{ - 31, // 0: authd.ABResponse.brokers_infos:type_name -> authd.ABResponse.BrokerInfo + 32, // 0: authd.ABResponse.brokers_infos:type_name -> authd.ABResponse.BrokerInfo 0, // 1: authd.SBRequest.mode:type_name -> authd.SessionMode 9, // 2: authd.GAMRequest.supported_ui_layouts:type_name -> authd.UILayout - 32, // 3: authd.GAMResponse.authentication_modes:type_name -> authd.GAMResponse.AuthenticationMode + 33, // 3: authd.GAMResponse.authentication_modes:type_name -> authd.GAMResponse.AuthenticationMode 9, // 4: authd.SAMResponse.ui_layout_info:type_name -> authd.UILayout - 33, // 5: authd.IARequest.authentication_data:type_name -> authd.IARequest.AuthenticationData - 27, // 6: authd.Users.users:type_name -> authd.User - 29, // 7: authd.Groups.groups:type_name -> authd.Group + 34, // 5: authd.IARequest.authentication_data:type_name -> authd.IARequest.AuthenticationData + 28, // 6: authd.Users.users:type_name -> authd.User + 30, // 7: authd.Groups.groups:type_name -> authd.Group 1, // 8: authd.PAM.AvailableBrokers:input_type -> authd.Empty 2, // 9: authd.PAM.GetPreviousBroker:input_type -> authd.GPBRequest 6, // 10: authd.PAM.SelectBroker:input_type -> authd.SBRequest @@ -2118,29 +2175,31 @@ var file_authd_proto_depIdxs = []int32{ 20, // 20: authd.UserService.UnlockUser:input_type -> authd.UnlockUserRequest 23, // 21: authd.UserService.SetUserID:input_type -> authd.SetUserIDRequest 25, // 22: authd.UserService.SetGroupID:input_type -> authd.SetGroupIDRequest - 21, // 23: authd.UserService.GetGroupByName:input_type -> authd.GetGroupByNameRequest - 22, // 24: authd.UserService.GetGroupByID:input_type -> authd.GetGroupByIDRequest - 1, // 25: authd.UserService.ListGroups:input_type -> authd.Empty - 4, // 26: authd.PAM.AvailableBrokers:output_type -> authd.ABResponse - 3, // 27: authd.PAM.GetPreviousBroker:output_type -> authd.GPBResponse - 7, // 28: authd.PAM.SelectBroker:output_type -> authd.SBResponse - 10, // 29: authd.PAM.GetAuthenticationModes:output_type -> authd.GAMResponse - 12, // 30: authd.PAM.SelectAuthenticationMode:output_type -> authd.SAMResponse - 14, // 31: authd.PAM.IsAuthenticated:output_type -> authd.IAResponse - 1, // 32: authd.PAM.EndSession:output_type -> authd.Empty - 1, // 33: authd.PAM.SetDefaultBrokerForUser:output_type -> authd.Empty - 27, // 34: authd.UserService.GetUserByName:output_type -> authd.User - 27, // 35: authd.UserService.GetUserByID:output_type -> authd.User - 28, // 36: authd.UserService.ListUsers:output_type -> authd.Users - 1, // 37: authd.UserService.LockUser:output_type -> authd.Empty - 1, // 38: authd.UserService.UnlockUser:output_type -> authd.Empty - 24, // 39: authd.UserService.SetUserID:output_type -> authd.SetUserIDResponse - 26, // 40: authd.UserService.SetGroupID:output_type -> authd.SetGroupIDResponse - 29, // 41: authd.UserService.GetGroupByName:output_type -> authd.Group - 29, // 42: authd.UserService.GetGroupByID:output_type -> authd.Group - 30, // 43: authd.UserService.ListGroups:output_type -> authd.Groups - 26, // [26:44] is the sub-list for method output_type - 8, // [8:26] is the sub-list for method input_type + 27, // 23: authd.UserService.SetShell:input_type -> authd.SetShellRequest + 21, // 24: authd.UserService.GetGroupByName:input_type -> authd.GetGroupByNameRequest + 22, // 25: authd.UserService.GetGroupByID:input_type -> authd.GetGroupByIDRequest + 1, // 26: authd.UserService.ListGroups:input_type -> authd.Empty + 4, // 27: authd.PAM.AvailableBrokers:output_type -> authd.ABResponse + 3, // 28: authd.PAM.GetPreviousBroker:output_type -> authd.GPBResponse + 7, // 29: authd.PAM.SelectBroker:output_type -> authd.SBResponse + 10, // 30: authd.PAM.GetAuthenticationModes:output_type -> authd.GAMResponse + 12, // 31: authd.PAM.SelectAuthenticationMode:output_type -> authd.SAMResponse + 14, // 32: authd.PAM.IsAuthenticated:output_type -> authd.IAResponse + 1, // 33: authd.PAM.EndSession:output_type -> authd.Empty + 1, // 34: authd.PAM.SetDefaultBrokerForUser:output_type -> authd.Empty + 28, // 35: authd.UserService.GetUserByName:output_type -> authd.User + 28, // 36: authd.UserService.GetUserByID:output_type -> authd.User + 29, // 37: authd.UserService.ListUsers:output_type -> authd.Users + 1, // 38: authd.UserService.LockUser:output_type -> authd.Empty + 1, // 39: authd.UserService.UnlockUser:output_type -> authd.Empty + 24, // 40: authd.UserService.SetUserID:output_type -> authd.SetUserIDResponse + 26, // 41: authd.UserService.SetGroupID:output_type -> authd.SetGroupIDResponse + 1, // 42: authd.UserService.SetShell:output_type -> authd.Empty + 30, // 43: authd.UserService.GetGroupByName:output_type -> authd.Group + 30, // 44: authd.UserService.GetGroupByID:output_type -> authd.Group + 31, // 45: authd.UserService.ListGroups:output_type -> authd.Groups + 27, // [27:46] is the sub-list for method output_type + 8, // [8:27] is the sub-list for method input_type 8, // [8:8] is the sub-list for extension type_name 8, // [8:8] is the sub-list for extension extendee 0, // [0:8] is the sub-list for field type_name @@ -2152,8 +2211,8 @@ func file_authd_proto_init() { return } file_authd_proto_msgTypes[8].OneofWrappers = []any{} - file_authd_proto_msgTypes[30].OneofWrappers = []any{} - file_authd_proto_msgTypes[32].OneofWrappers = []any{ + file_authd_proto_msgTypes[31].OneofWrappers = []any{} + file_authd_proto_msgTypes[33].OneofWrappers = []any{ (*IARequest_AuthenticationData_Secret)(nil), (*IARequest_AuthenticationData_Wait)(nil), (*IARequest_AuthenticationData_Skip)(nil), @@ -2165,7 +2224,7 @@ func file_authd_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_authd_proto_rawDesc), len(file_authd_proto_rawDesc)), NumEnums: 1, - NumMessages: 33, + NumMessages: 34, NumExtensions: 0, NumServices: 2, }, diff --git a/internal/proto/authd/authd.proto b/internal/proto/authd/authd.proto index df108d0f51..5706107bf1 100644 --- a/internal/proto/authd/authd.proto +++ b/internal/proto/authd/authd.proto @@ -137,6 +137,7 @@ service UserService { rpc UnlockUser(UnlockUserRequest) returns (Empty); rpc SetUserID(SetUserIDRequest) returns (SetUserIDResponse); rpc SetGroupID(SetGroupIDRequest) returns (SetGroupIDResponse); + rpc SetShell(SetShellRequest) returns (Empty); rpc GetGroupByName(GetGroupByNameRequest) returns (Group); rpc GetGroupByID(GetGroupByIDRequest) returns (Group); @@ -196,6 +197,11 @@ message SetGroupIDResponse { repeated string warnings = 3; } +message SetShellRequest { + string name = 1; + string shell = 2; +} + message User { string name = 1; uint32 uid = 2; diff --git a/internal/proto/authd/authd_grpc.pb.go b/internal/proto/authd/authd_grpc.pb.go index 8685f6fdca..b6ca785927 100644 --- a/internal/proto/authd/authd_grpc.pb.go +++ b/internal/proto/authd/authd_grpc.pb.go @@ -394,6 +394,7 @@ const ( UserService_UnlockUser_FullMethodName = "/authd.UserService/UnlockUser" UserService_SetUserID_FullMethodName = "/authd.UserService/SetUserID" UserService_SetGroupID_FullMethodName = "/authd.UserService/SetGroupID" + UserService_SetShell_FullMethodName = "/authd.UserService/SetShell" UserService_GetGroupByName_FullMethodName = "/authd.UserService/GetGroupByName" UserService_GetGroupByID_FullMethodName = "/authd.UserService/GetGroupByID" UserService_ListGroups_FullMethodName = "/authd.UserService/ListGroups" @@ -410,6 +411,7 @@ type UserServiceClient interface { UnlockUser(ctx context.Context, in *UnlockUserRequest, opts ...grpc.CallOption) (*Empty, error) SetUserID(ctx context.Context, in *SetUserIDRequest, opts ...grpc.CallOption) (*SetUserIDResponse, error) SetGroupID(ctx context.Context, in *SetGroupIDRequest, opts ...grpc.CallOption) (*SetGroupIDResponse, error) + SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*Empty, error) GetGroupByName(ctx context.Context, in *GetGroupByNameRequest, opts ...grpc.CallOption) (*Group, error) GetGroupByID(ctx context.Context, in *GetGroupByIDRequest, opts ...grpc.CallOption) (*Group, error) ListGroups(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Groups, error) @@ -493,6 +495,16 @@ func (c *userServiceClient) SetGroupID(ctx context.Context, in *SetGroupIDReques return out, nil } +func (c *userServiceClient) SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Empty) + err := c.cc.Invoke(ctx, UserService_SetShell_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *userServiceClient) GetGroupByName(ctx context.Context, in *GetGroupByNameRequest, opts ...grpc.CallOption) (*Group, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Group) @@ -534,6 +546,7 @@ type UserServiceServer interface { UnlockUser(context.Context, *UnlockUserRequest) (*Empty, error) SetUserID(context.Context, *SetUserIDRequest) (*SetUserIDResponse, error) SetGroupID(context.Context, *SetGroupIDRequest) (*SetGroupIDResponse, error) + SetShell(context.Context, *SetShellRequest) (*Empty, error) GetGroupByName(context.Context, *GetGroupByNameRequest) (*Group, error) GetGroupByID(context.Context, *GetGroupByIDRequest) (*Group, error) ListGroups(context.Context, *Empty) (*Groups, error) @@ -568,6 +581,9 @@ func (UnimplementedUserServiceServer) SetUserID(context.Context, *SetUserIDReque func (UnimplementedUserServiceServer) SetGroupID(context.Context, *SetGroupIDRequest) (*SetGroupIDResponse, error) { return nil, status.Error(codes.Unimplemented, "method SetGroupID not implemented") } +func (UnimplementedUserServiceServer) SetShell(context.Context, *SetShellRequest) (*Empty, error) { + return nil, status.Error(codes.Unimplemented, "method SetShell not implemented") +} func (UnimplementedUserServiceServer) GetGroupByName(context.Context, *GetGroupByNameRequest) (*Group, error) { return nil, status.Error(codes.Unimplemented, "method GetGroupByName not implemented") } @@ -724,6 +740,24 @@ func _UserService_SetGroupID_Handler(srv interface{}, ctx context.Context, dec f return interceptor(ctx, in, info, handler) } +func _UserService_SetShell_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetShellRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UserServiceServer).SetShell(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UserService_SetShell_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UserServiceServer).SetShell(ctx, req.(*SetShellRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _UserService_GetGroupByName_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetGroupByNameRequest) if err := dec(in); err != nil { @@ -813,6 +847,10 @@ var UserService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetGroupID", Handler: _UserService_SetGroupID_Handler, }, + { + MethodName: "SetShell", + Handler: _UserService_SetShell_Handler, + }, { MethodName: "GetGroupByName", Handler: _UserService_GetGroupByName_Handler, diff --git a/internal/services/permissions/permissions.go b/internal/services/permissions/permissions.go index a2c00eb6d7..b6b8cd4506 100644 --- a/internal/services/permissions/permissions.go +++ b/internal/services/permissions/permissions.go @@ -40,18 +40,50 @@ func New(args ...Option) Manager { // CheckRequestIsFromRoot checks if the current gRPC request is from a root user and returns an error if not. // The pid and uid are extracted from peerCredsInfo in the gRPC context. func (m Manager) CheckRequestIsFromRoot(ctx context.Context) (err error) { + isRoot, err := m.isRequestFromRoot(ctx) + if err != nil { + return err + } + if !isRoot { + return errors.New("only root can perform this operation") + } + return nil +} + +// CheckRequestIsFromRootOrUID checks if the current gRPC request is from a root +// user or a specified user and returns an error if not. +func (m Manager) CheckRequestIsFromRootOrUID(ctx context.Context, uid uint32) (err error) { + isRoot, err := m.isRequestFromRoot(ctx) + if err != nil { + return err + } + if isRoot { + return nil + } + + isFromUID, err := m.isRequestFromUID(ctx, uid) + if err != nil { + return err + } + if !isFromUID { + return errors.New("only root or the specified user can perform this operation") + } + return nil +} + +func (m Manager) isRequestFromRoot(ctx context.Context) (bool, error) { + return m.isRequestFromUID(ctx, m.rootUID) +} + +func (m Manager) isRequestFromUID(ctx context.Context, uid uint32) (bool, error) { p, ok := peer.FromContext(ctx) if !ok { - return errors.New("context request doesn't have gRPC peer information") + return false, errors.New("context request doesn't have gRPC peer information") } pci, ok := p.AuthInfo.(peerCredsInfo) if !ok { - return errors.New("context request doesn't have valid gRPC peer credential information") + return false, errors.New("context request doesn't have valid gRPC peer credential information") } - if pci.uid != m.rootUID { - return errors.New("only root can perform this operation") - } - - return nil + return pci.uid == uid, nil } diff --git a/internal/services/testdata/golden/TestRegisterGRPCServices b/internal/services/testdata/golden/TestRegisterGRPCServices index 871d7f2387..c043749b06 100644 --- a/internal/services/testdata/golden/TestRegisterGRPCServices +++ b/internal/services/testdata/golden/TestRegisterGRPCServices @@ -51,6 +51,9 @@ authd.UserService: - name: SetGroupID isclientstream: false isserverstream: false + - name: SetShell + isclientstream: false + isserverstream: false - name: SetUserID isclientstream: false isserverstream: false diff --git a/internal/services/user/user.go b/internal/services/user/user.go index 4d976d8297..b76e64dedf 100644 --- a/internal/services/user/user.go +++ b/internal/services/user/user.go @@ -270,6 +270,31 @@ func (s Service) SetGroupID(ctx context.Context, req *authd.SetGroupIDRequest) ( }, nil } +// SetShell sets the shell of a user. +func (s Service) SetShell(ctx context.Context, req *authd.SetShellRequest) (*authd.Empty, error) { + // authd uses lowercase group names. + name := strings.ToLower(req.GetName()) + + user, err := s.userManager.UserByName(name) + if errors.Is(err, users.NoDataFoundError{}) { + return nil, status.Errorf(codes.NotFound, "user %q not found", name) + } + if err != nil { + return nil, grpcError(err) + } + + if err := s.permissionManager.CheckRequestIsFromRootOrUID(ctx, user.UID); err != nil { + return nil, status.Error(codes.PermissionDenied, err.Error()) + } + + if err = s.userManager.SetShell(name, req.GetShell()); err != nil { + log.Errorf(ctx, "SetShell: %v", err) + return nil, grpcError(err) + } + + return &authd.Empty{}, nil +} + // userToProtobuf converts a types.UserEntry to authd.User. func userToProtobuf(u types.UserEntry) *authd.User { return &authd.User{ diff --git a/internal/users/db/update.go b/internal/users/db/update.go index 7c420d09f9..f60d10737e 100644 --- a/internal/users/db/update.go +++ b/internal/users/db/update.go @@ -373,3 +373,20 @@ func (m *Manager) SetGroupID(groupName string, newGID uint32) ([]UserRow, error) return users, nil } + +// SetShell updates the shell of a user. +func (m *Manager) SetShell(username, shell string) error { + query := `UPDATE users SET shell = ? WHERE name = ?` + res, err := m.db.Exec(query, shell, username) + if err != nil { + return fmt.Errorf("failed to update shell for user: %w", err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return NewUserNotFoundError(username) + } + return nil +} diff --git a/internal/users/manager.go b/internal/users/manager.go index 3b978de19b..2ab4164b7a 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -665,6 +665,33 @@ func checkHomeDirOwner(home string, uid, gid uint32) error { return nil } +// SetShell sets the shell for the given user. +func (m *Manager) SetShell(username, shell string) (err error) { + if username == "" { + return errors.New("empty username") + } + + err = checkValidShell(shell) + if err != nil { + return err + } + + stat, err := os.Stat(shell) + if errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("shell %q does not exist", shell) + } + + if stat.IsDir() || stat.Mode()&0111 == 0 { + return fmt.Errorf("shell %q is not executable", shell) + } + + if err = m.db.SetShell(username, shell); err != nil { + return err + } + + return nil +} + // BrokerForUser returns the broker ID for the given user. func (m *Manager) BrokerForUser(username string) (string, error) { u, err := m.db.UserByName(username) diff --git a/internal/users/userutils.go b/internal/users/userutils.go new file mode 100644 index 0000000000..db254c2600 --- /dev/null +++ b/internal/users/userutils.go @@ -0,0 +1,68 @@ +package users + +import ( + "errors" + "os" + "path" + "strings" + "unicode/utf8" + + "golang.org/x/sys/unix" +) + +func checkValidPasswdField(value string) (err error) { + if value == "" { + return errors.New("value cannot be empty") + } + + if !utf8.ValidString(value) { + return errors.New("value must be valid UTF-8") + } + + if strings.ContainsRune(value, ':') { + return errors.New("value cannot contain ':' character") + } + + for _, r := range value { + if r < 32 || r == 127 { + return errors.New("value cannot contain control characters") + } + } + + return nil +} + +func checkValidShell(shell string) (err error) { + // Do the same checks as systemd-homed in shell_is_ok: + // https://github.com/systemd/systemd/blob/ba67af7efb7b743ba1974ef9ceb53fba0e3f9e21/src/home/homectl.c#L2812 + if err = checkValidPasswdField(shell); err != nil { + return err + } + if !path.IsAbs(shell) { + return errors.New("shell must be an absolute path") + } + if shell != path.Clean(shell) { + return errors.New("shell path must be normalized") + } + // PATH_MAX is counted with the terminating null byte + if unix.PathMax-1 < len(shell) { + return errors.New("shell path is too long") + } + + // Check if the shell is in the list of allowed shells in /etc/shells + shells, err := os.ReadFile("/etc/shells") + if err != nil { + return err + } + for _, allowedShell := range strings.Split(string(shells), "\n") { + if allowedShell[0] == '#' { + // Skip comments + continue + } + if allowedShell == shell { + return nil + } + } + + return errors.New("shell is not allowed in /etc/shells") +} From 7af71107ca98a40ac2b26db351eaef7e2db3b8f1 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 18 Feb 2026 15:45:02 +0100 Subject: [PATCH 02/14] Test SetShell methods in db and users packages --- internal/users/db/db_test.go | 39 ++++++ .../TestSetShell/Update_existing_user_shell | 18 +++ internal/users/manager_test.go | 121 ++++++++++++++++++ .../TestSetShell/Successfully_set_shell | 18 +++ 4 files changed, 196 insertions(+) create mode 100644 internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell create mode 100644 internal/users/testdata/golden/TestSetShell/Successfully_set_shell diff --git a/internal/users/db/db_test.go b/internal/users/db/db_test.go index 47cce2060f..47926fc7cc 100644 --- a/internal/users/db/db_test.go +++ b/internal/users/db/db_test.go @@ -1013,6 +1013,45 @@ func TestSetGroupID(t *testing.T) { } } +func TestSetShell(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + nonExistentUser bool + + wantErr bool + }{ + "Update_existing_user_shell": {}, + + "Error_on_nonexistent_user": {nonExistentUser: true, wantErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + m := initDB(t, "one_user_and_group") + + username := "user1" + if tc.nonExistentUser { + username = "nonexistent" + } + + err := m.SetShell(username, "/bin/new-shell") + if tc.wantErr { + require.Error(t, err, "SetShell should return an error for case %q", name) + return + } + require.NoError(t, err, "SetShell should not return an error for case %q", name) + + dbContent, err := db.Z_ForTests_DumpNormalizedYAML(m) + require.NoError(t, err) + + golden.CheckOrUpdate(t, dbContent) + }) + } +} + func TestRemoveDb(t *testing.T) { t.Parallel() diff --git a/internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell b/internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell new file mode 100644 index 0000000000..c3b17bbdd4 --- /dev/null +++ b/internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /bin/new-shell + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/manager_test.go b/internal/users/manager_test.go index 033dffdea5..8d113f6432 100644 --- a/internal/users/manager_test.go +++ b/internal/users/manager_test.go @@ -1301,6 +1301,127 @@ func TestUpdateUserAfterUnlock(t *testing.T) { require.NoError(t, err, "UpdateUser should not fail") } +func TestSetShell(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + nonExistentUser bool + emptyUsername bool + emptyShell bool + shell string + + wantErr bool + }{ + "Successfully_set_shell": {}, + + // checkValidPasswdField error cases + "Error_if_shell_is_empty": { + emptyShell: true, + wantErr: true, + }, + "Error_if_shell_contains_invalid_utf8": { + shell: "/bin/\xff\xfeinvalid", + wantErr: true, + }, + "Error_if_shell_contains_colon": { + shell: "/bin/sh:bash", + wantErr: true, + }, + "Error_if_shell_contains_control_characters": { + shell: "/bin/sh\x00", + wantErr: true, + }, + "Error_if_shell_contains_control_character_tab": { + shell: "/bin/sh\t", + wantErr: true, + }, + "Error_if_shell_contains_control_character_newline": { + shell: "/bin/sh\n", + wantErr: true, + }, + "Error_if_shell_contains_control_character_del": { + shell: "/bin/sh\x7f", + wantErr: true, + }, + + // checkValidShell error cases + "Error_if_shell_is_not_absolute_path": { + shell: "bin/sh", + wantErr: true, + }, + "Error_if_shell_path_is_not_normalized": { + shell: "/bin/../bin/sh", + wantErr: true, + }, + "Error_if_shell_path_is_not_normalized_with_dot": { + shell: "/bin/./sh", + wantErr: true, + }, + "Error_if_shell_path_is_too_long": { + shell: "/" + strings.Repeat("a", 4096), + wantErr: true, + }, + + // Other error cases + "Error_if_shell_does_not_exist": { + shell: "/doesnotexist", + wantErr: true, + }, + "Error_if_shell_is_directory": { + shell: "/etc", + wantErr: true, + }, + "Error_if_shell_is_not_executable": { + shell: "/etc/passwd", + wantErr: true, + }, + "Error_if_user_does_not_exist": { + nonExistentUser: true, + wantErr: true, + }, + "Error_if_username_is_empty": { + emptyUsername: true, + wantErr: true, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + dbDir := t.TempDir() + err := db.Z_ForTests_CreateDBFromYAML(filepath.Join("testdata", "db", "one_user_and_group.db.yaml"), dbDir) + require.NoError(t, err, "Setup: could not create database from testdata") + + m := newManagerForTests(t, dbDir) + + username := "user1" + if tc.nonExistentUser { + username = "nonexistent" + } else if tc.emptyUsername { + username = "" + } + + shell := "/bin/sh" + if tc.emptyShell { + shell = "" + } else if tc.shell != "" { + shell = tc.shell + } + + err = m.SetShell(username, shell) + requireErrorAssertions(t, err, nil, tc.wantErr) + if tc.wantErr { + return + } + + yamlData, err := db.Z_ForTests_DumpNormalizedYAML(m.DB()) + require.NoError(t, err) + golden.CheckOrUpdate(t, yamlData) + }) + } +} + func requireErrorAssertions(t *testing.T, gotErr, wantErrType error, wantErr bool) { t.Helper() diff --git a/internal/users/testdata/golden/TestSetShell/Successfully_set_shell b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell new file mode 100644 index 0000000000..7fa3788902 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /bin/sh + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 From b0cfa01fd16cd38417d5689b589787cab7b96b6a Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Mon, 15 Dec 2025 17:55:45 +0100 Subject: [PATCH 03/14] Add `authctl user set-shell` command --- cmd/authctl/user/set-shell.go | 46 ++++++++++++++++++ cmd/authctl/user/set-shell_test.go | 48 +++++++++++++++++++ .../Error_when_user_does_not_exist | 1 + .../TestSetShellCommand/Set_shell_success | 0 .../TestUserCommand/Error_on_invalid_command | 1 + .../TestUserCommand/Error_on_invalid_flag | 1 + .../testdata/golden/TestUserCommand/Help_flag | 1 + .../Usage_message_when_no_args | 1 + cmd/authctl/user/user.go | 1 + 9 files changed, 100 insertions(+) create mode 100644 cmd/authctl/user/set-shell.go create mode 100644 cmd/authctl/user/set-shell_test.go create mode 100644 cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist create mode 100644 cmd/authctl/user/testdata/golden/TestSetShellCommand/Set_shell_success diff --git a/cmd/authctl/user/set-shell.go b/cmd/authctl/user/set-shell.go new file mode 100644 index 0000000000..7b180b182b --- /dev/null +++ b/cmd/authctl/user/set-shell.go @@ -0,0 +1,46 @@ +package user + +import ( + "context" + + "github.com/canonical/authd/cmd/authctl/internal/client" + "github.com/canonical/authd/cmd/authctl/internal/completion" + "github.com/canonical/authd/internal/proto/authd" + "github.com/spf13/cobra" +) + +var setShellCmd = &cobra.Command{ + Use: "set-shell ", + Short: "Set the login shell for a user", + Args: cobra.ExactArgs(2), + ValidArgsFunction: setShellCompletionFunc, + RunE: runSetShell, +} + +func runSetShell(cmd *cobra.Command, args []string) error { + name := args[0] + shell := args[1] + + svc, err := client.NewUserServiceClient() + if err != nil { + return err + } + + _, err = svc.SetShell(context.Background(), &authd.SetShellRequest{ + Name: name, + Shell: shell, + }) + if err != nil { + return err + } + + return nil +} + +func setShellCompletionFunc(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) == 0 { + return completion.Users(cmd, args, toComplete) + } + + return nil, cobra.ShellCompDirectiveNoFileComp +} diff --git a/cmd/authctl/user/set-shell_test.go b/cmd/authctl/user/set-shell_test.go new file mode 100644 index 0000000000..0d0b81dcdf --- /dev/null +++ b/cmd/authctl/user/set-shell_test.go @@ -0,0 +1,48 @@ +package user_test + +import ( + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/canonical/authd/internal/testutils" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" +) + +func TestSetShellCommand(t *testing.T) { + t.Parallel() + + daemonSocket := testutils.StartAuthd(t, daemonPath, + testutils.WithGroupFile(filepath.Join("testdata", "empty.group")), + testutils.WithPreviousDBState("one_user_and_group"), + testutils.WithCurrentUserAsRoot, + ) + + err := os.Setenv("AUTHD_SOCKET", daemonSocket) + require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + + tests := map[string]struct { + args []string + + expectedExitCode int + }{ + "Set_shell_success": {args: []string{"set-shell", "user1", "/bin/bash"}, expectedExitCode: 0}, + + "Error_when_user_does_not_exist": { + args: []string{"set-shell", "invaliduser", "/bin/bash"}, + expectedExitCode: int(codes.NotFound), + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + //nolint:gosec // G204 it's safe to use exec.Command with a variable here + cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + testutils.CheckCommand(t, cmd, tc.expectedExitCode) + }) + } +} diff --git a/cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist b/cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist new file mode 100644 index 0000000000..93dd7dd5ff --- /dev/null +++ b/cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist @@ -0,0 +1 @@ +Error: user "invaliduser" not found diff --git a/cmd/authctl/user/testdata/golden/TestSetShellCommand/Set_shell_success b/cmd/authctl/user/testdata/golden/TestSetShellCommand/Set_shell_success new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command index a66b03e2c9..b6b2bbed64 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command @@ -6,6 +6,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag index 3408cec6d8..beeb0304bd 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag @@ -6,6 +6,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag b/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag index ee4765c1cc..7a303d9166 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag @@ -8,6 +8,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args b/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args index d83ced5684..c90581f34f 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args @@ -6,6 +6,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/user.go b/cmd/authctl/user/user.go index 109743aa17..e4e218b230 100644 --- a/cmd/authctl/user/user.go +++ b/cmd/authctl/user/user.go @@ -17,4 +17,5 @@ func init() { UserCmd.AddCommand(lockCmd) UserCmd.AddCommand(unlockCmd) UserCmd.AddCommand(setUIDCmd) + UserCmd.AddCommand(setShellCmd) } From 45ab5d4d14d1d07df001ac2d85b6022e800659a9 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 18 Feb 2026 15:56:08 +0100 Subject: [PATCH 04/14] docs: Generate docs for `authctl user set-shell` --- docs/reference/cli/authctl_user.md | 1 + docs/reference/cli/authctl_user_set-shell.md | 18 ++++++++++++++++++ docs/reference/cli/index.md | 1 + 3 files changed, 20 insertions(+) create mode 100644 docs/reference/cli/authctl_user_set-shell.md diff --git a/docs/reference/cli/authctl_user.md b/docs/reference/cli/authctl_user.md index 3a6e5606ae..e6163ff3b2 100644 --- a/docs/reference/cli/authctl_user.md +++ b/docs/reference/cli/authctl_user.md @@ -16,6 +16,7 @@ authctl user [flags] * [authctl](authctl.md) - CLI tool to interact with authd * [authctl user lock](authctl_user_lock.md) - Lock (disable) a user managed by authd +* [authctl user set-shell](authctl_user_set-shell.md) - Set the login shell for a user * [authctl user set-uid](authctl_user_set-uid.md) - Set the UID of a user managed by authd * [authctl user unlock](authctl_user_unlock.md) - Unlock (enable) a user managed by authd diff --git a/docs/reference/cli/authctl_user_set-shell.md b/docs/reference/cli/authctl_user_set-shell.md new file mode 100644 index 0000000000..2668479223 --- /dev/null +++ b/docs/reference/cli/authctl_user_set-shell.md @@ -0,0 +1,18 @@ +## authctl user set-shell + +Set the login shell for a user + +``` +authctl user set-shell [flags] +``` + +### Options + +``` + -h, --help help for set-shell +``` + +### SEE ALSO + +* [authctl user](authctl_user.md) - Commands related to users + diff --git a/docs/reference/cli/index.md b/docs/reference/cli/index.md index c9f8a20962..ff0686ff63 100644 --- a/docs/reference/cli/index.md +++ b/docs/reference/cli/index.md @@ -21,6 +21,7 @@ authctl_user authctl_user_lock authctl_user_unlock authctl_user_set-uid +authctl_user_set-shell ``` ```{toctree} From 6e3c5d52162cb9241c9cdc5998916a5366d79de6 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 18 Feb 2026 16:14:25 +0100 Subject: [PATCH 05/14] Rename TestIsRequestFromRoot -> TestCheckRequestIsFromRoot The function that's tested is CheckRequestIsFromRoot. --- internal/services/permissions/permissions_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/services/permissions/permissions_test.go b/internal/services/permissions/permissions_test.go index 04605cabb4..ebc4177c0b 100644 --- a/internal/services/permissions/permissions_test.go +++ b/internal/services/permissions/permissions_test.go @@ -21,7 +21,7 @@ func TestNew(t *testing.T) { require.NotNil(t, pm, "New permission manager is created") } -func TestIsRequestFromRoot(t *testing.T) { +func TestCheckRequestIsFromRoot(t *testing.T) { t.Parallel() tests := map[string]struct { From e379217b9487cef7d7942a14797b3292111ac0e5 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 18 Feb 2026 16:36:36 +0100 Subject: [PATCH 06/14] Add TestCheckRequestIsFromRootOrUID Also renames peerCredsInfo to peerAuthInfo because the field in peer.Peer is named AuthInfo. Also improves the comment of WithUnixPeerCreds. --- internal/services/permissions/export_test.go | 6 +- .../services/permissions/internal_test.go | 2 +- internal/services/permissions/permissions.go | 4 +- .../services/permissions/permissions_test.go | 108 ++++++++++++++---- internal/services/permissions/servercreds.go | 8 +- 5 files changed, 95 insertions(+), 33 deletions(-) diff --git a/internal/services/permissions/export_test.go b/internal/services/permissions/export_test.go index 728aaa902b..7c43b1aad6 100644 --- a/internal/services/permissions/export_test.go +++ b/internal/services/permissions/export_test.go @@ -1,9 +1,9 @@ package permissions -type PeerCredsInfo = peerCredsInfo +type PeerAuthInfo = peerAuthInfo -func NewTestPeerCredsInfo(uid uint32, pid int32) PeerCredsInfo { - return PeerCredsInfo{uid: uid, pid: pid} +func NewTestPeerAuthInfo(uid uint32, pid int32) PeerAuthInfo { + return PeerAuthInfo{uid: uid, pid: pid} } var ( diff --git a/internal/services/permissions/internal_test.go b/internal/services/permissions/internal_test.go index 8bbdf6752b..ea662527f4 100644 --- a/internal/services/permissions/internal_test.go +++ b/internal/services/permissions/internal_test.go @@ -16,7 +16,7 @@ import ( func TestPeerCredsInfoAuthType(t *testing.T) { t.Parallel() - p := peerCredsInfo{ + p := peerAuthInfo{ uid: 11111, pid: 22222, } diff --git a/internal/services/permissions/permissions.go b/internal/services/permissions/permissions.go index b6b8cd4506..27c6bb7b0c 100644 --- a/internal/services/permissions/permissions.go +++ b/internal/services/permissions/permissions.go @@ -38,7 +38,7 @@ func New(args ...Option) Manager { } // CheckRequestIsFromRoot checks if the current gRPC request is from a root user and returns an error if not. -// The pid and uid are extracted from peerCredsInfo in the gRPC context. +// The pid and uid are extracted from peerAuthInfo in the gRPC context. func (m Manager) CheckRequestIsFromRoot(ctx context.Context) (err error) { isRoot, err := m.isRequestFromRoot(ctx) if err != nil { @@ -80,7 +80,7 @@ func (m Manager) isRequestFromUID(ctx context.Context, uid uint32) (bool, error) if !ok { return false, errors.New("context request doesn't have gRPC peer information") } - pci, ok := p.AuthInfo.(peerCredsInfo) + pci, ok := p.AuthInfo.(peerAuthInfo) if !ok { return false, errors.New("context request doesn't have valid gRPC peer credential information") } diff --git a/internal/services/permissions/permissions_test.go b/internal/services/permissions/permissions_test.go index ebc4177c0b..fd231a880a 100644 --- a/internal/services/permissions/permissions_test.go +++ b/internal/services/permissions/permissions_test.go @@ -26,39 +26,22 @@ func TestCheckRequestIsFromRoot(t *testing.T) { tests := map[string]struct { currentUserNotRoot bool - noPeerCredsInfo bool - noAuthInfo bool + noPeerInfo bool + noPeerAuthInfo bool wantErr bool }{ "Granted_if_current_user_considered_as_root": {}, - "Error_as_deny_when_current_user_is_not_root": {currentUserNotRoot: true, wantErr: true}, - "Error_as_deny_when_missing_peer_creds_Info": {noPeerCredsInfo: true, wantErr: true}, - "Error_as_deny_when_missing_auth_info_creds": {noAuthInfo: true, wantErr: true}, + "Error_if_current_user_is_not_root": {currentUserNotRoot: true, wantErr: true}, + "Error_if_missing_peer_info": {noPeerInfo: true, wantErr: true}, + "Error_if_missing_peer_auth_info": {noPeerAuthInfo: true, wantErr: true}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { t.Parallel() - // Setup peer creds info - ctx := context.Background() - if !tc.noPeerCredsInfo { - var authInfo credentials.AuthInfo - if !tc.noAuthInfo { - uid := permissions.CurrentUserUID() - pid := os.Getpid() - if pid > math.MaxInt32 { - t.Fatalf("Setup: pid is too large to be converted to int32: %d", pid) - } - //nolint:gosec // we did check the conversion check beforehand. - authInfo = permissions.NewTestPeerCredsInfo(uid, int32(os.Getpid())) - } - p := peer.Peer{ - AuthInfo: authInfo, - } - ctx = peer.NewContext(ctx, &p) - } + ctx := setupPermissionTestContext(t, tc.noPeerInfo, tc.noPeerAuthInfo) var opts []permissions.Option if !tc.currentUserNotRoot { @@ -77,6 +60,60 @@ func TestCheckRequestIsFromRoot(t *testing.T) { } } +func TestCheckRequestIsFromRootOrUID(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + currentUserNotRoot bool + useCurrentUID bool + useDifferentUID bool + noPeerInfo bool + noPeerAuthInfo bool + + wantErr bool + }{ + "Granted_if_current_user_considered_as_root": {}, + "Granted_if_current_user_matches_target_uid": {currentUserNotRoot: true, useCurrentUID: true}, + + "Error_if_current_user_is_neither_root_nor_target_uid": { + currentUserNotRoot: true, + useDifferentUID: true, + wantErr: true, + }, + "Error_if_missing_peer_info": {noPeerInfo: true, wantErr: true}, + "Error_if_missing_peer_auth_info": {noPeerAuthInfo: true, wantErr: true}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctx := setupPermissionTestContext(t, tc.noPeerInfo, tc.noPeerAuthInfo) + + currentUID := permissions.CurrentUserUID() + targetUID := currentUID + + // If we want a different UID, use a different value + if tc.useDifferentUID { + targetUID = currentUID + 1000 + } + + var opts []permissions.Option + if !tc.currentUserNotRoot { + opts = append(opts, permissions.Z_ForTests_WithCurrentUserAsRoot()) + } + pm := permissions.New(opts...) + + err := pm.CheckRequestIsFromRootOrUID(ctx, targetUID) + + if tc.wantErr { + require.Error(t, err, "CheckRequestIsFromRootOrUID should deny access but didn't") + return + } + require.NoError(t, err, "CheckRequestIsFromRootOrUID should allow access but didn't") + }) + } +} + func TestWithUnixPeerCreds(t *testing.T) { t.Parallel() @@ -84,3 +121,28 @@ func TestWithUnixPeerCreds(t *testing.T) { require.NotNil(t, g, "New gRPC with Unix Peer Creds is created") } + +// setupPermissionTestContext creates a context with peer credentials for testing. +func setupPermissionTestContext(t *testing.T, noPeerInfo, noAuthInfo bool) context.Context { + t.Helper() + + ctx := context.Background() + if noPeerInfo { + return ctx + } + + var authInfo credentials.AuthInfo + if !noAuthInfo { + uid := permissions.CurrentUserUID() + pid := os.Getpid() + if pid > math.MaxInt32 { + require.Fail(t, "Setup: pid is too large to be converted to int32: %d", pid) + } + //nolint:gosec // we checked for an integer overflow above. + authInfo = permissions.NewTestPeerAuthInfo(uid, int32(pid)) + } + p := peer.Peer{ + AuthInfo: authInfo, + } + return peer.NewContext(ctx, &p) +} diff --git a/internal/services/permissions/servercreds.go b/internal/services/permissions/servercreds.go index 72b3d71d3a..85f0383fb7 100644 --- a/internal/services/permissions/servercreds.go +++ b/internal/services/permissions/servercreds.go @@ -13,7 +13,7 @@ import ( "google.golang.org/grpc/credentials" ) -// WithUnixPeerCreds returns the credentials of the caller. +// WithUnixPeerCreds returns a ServerOption that sets credentials for server connections. func WithUnixPeerCreds() grpc.ServerOption { return grpc.Creds(serverPeerCreds{}) } @@ -57,7 +57,7 @@ func (serverPeerCreds) ServerHandshake(conn net.Conn) (n net.Conn, c credentials return nil, nil, fmt.Errorf("Control() error: %v", err) } - return conn, peerCredsInfo{uid: cred.Uid, pid: cred.Pid}, nil + return conn, peerAuthInfo{uid: cred.Uid, pid: cred.Pid}, nil } func (serverPeerCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { return conn, nil, nil @@ -66,12 +66,12 @@ func (serverPeerCreds) Info() credentials.ProtocolInfo { return credent func (serverPeerCreds) Clone() credentials.TransportCredentials { return nil } func (serverPeerCreds) OverrideServerName(_ string) error { return nil } -type peerCredsInfo struct { +type peerAuthInfo struct { uid uint32 pid int32 } // AuthType returns a string containing the uid and pid of caller. -func (p peerCredsInfo) AuthType() string { +func (p peerAuthInfo) AuthType() string { return fmt.Sprintf("uid: %d, pid: %d", p.uid, p.pid) } From eb9383dd5f1aace32e0db098579d3b8e4baf1be7 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 18 Feb 2026 16:56:02 +0100 Subject: [PATCH 07/14] Test the SetShell method of the user service --- internal/services/user/user_test.go | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/internal/services/user/user_test.go b/internal/services/user/user_test.go index 91afdf1e95..78c7dc1b27 100644 --- a/internal/services/user/user_test.go +++ b/internal/services/user/user_test.go @@ -423,6 +423,43 @@ func TestSetGroupID(t *testing.T) { } } +func TestSetShell(t *testing.T) { + tests := map[string]struct { + sourceDB string + + username string + newShell string + closeDB bool + currentUserNotRoot bool + + wantErr bool + }{ + "Successfully_set_shell": {username: "user1", newShell: "/bin/sh"}, + "Successfully_set_shell_when_username_has_uppercase_char": {username: "USER1", newShell: "/bin/sh"}, + + "Error_when_not_root": {username: "user1", newShell: "/bin/sh", currentUserNotRoot: true, wantErr: true}, + "Error_when_users_manager_fails_to_set_shell": {username: "doesnotexist", newShell: "/bin/sh", wantErr: true}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + client, m := newUserServiceClient(t, tc.sourceDB, tc.currentUserNotRoot) + + if tc.closeDB { + // Close the database to trigger a database error + err := userstestutils.DBManager(m).Close() + require.NoError(t, err, "Setup: failed to close database") + } + + _, err := client.SetShell(context.Background(), &authd.SetShellRequest{Name: tc.username, Shell: tc.newShell}) + if tc.wantErr { + require.Error(t, err, "SetShell should return an error, but did not") + return + } + require.NoError(t, err, "SetShell should not return an error, but did") + }) + } +} + // newUserServiceClient returns a new gRPC client for the CLI service. func newUserServiceClient(t *testing.T, dbFile string, currentUserNotRoot ...bool) (client authd.UserServiceClient, userManager *users.Manager) { t.Helper() From b8f603871a75673e907c9fc3e9fbd8535103f32f Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Mon, 23 Feb 2026 12:58:43 +0100 Subject: [PATCH 08/14] Add envutils The envutils package provides utilities for manipulating string slices representing environment variables. --- internal/envutils/envutils.go | 46 ++++++ internal/envutils/envutils_test.go | 241 +++++++++++++++++++++++++++++ 2 files changed, 287 insertions(+) create mode 100644 internal/envutils/envutils.go create mode 100644 internal/envutils/envutils_test.go diff --git a/internal/envutils/envutils.go b/internal/envutils/envutils.go new file mode 100644 index 0000000000..9f24c64129 --- /dev/null +++ b/internal/envutils/envutils.go @@ -0,0 +1,46 @@ +// Package envutils provides utilities for manipulating string slices representing environment variables. +package envutils + +import ( + "errors" + "fmt" + "strings" +) + +// Getenv retrieves the value of an environment variable from a slice of strings. +func Getenv(env []string, key string) string { + for _, kv := range env { + if strings.HasPrefix(kv, key+"=") { + return strings.TrimPrefix(kv, key+"=") + } + } + return "" +} + +// Setenv sets an environment variable in a slice of strings. +func Setenv(env []string, key, value string) ([]string, error) { + if len(key) == 0 { + return nil, errors.New("empty key") + } + if strings.ContainsAny(key, "="+"\x00") { + return nil, fmt.Errorf("invalid key: %q", key) + } + if strings.ContainsRune(value, '\x00') { + return nil, fmt.Errorf("invalid value: %q", value) + } + + kv := fmt.Sprintf("%s=%s", key, value) + + // Check if the key is already set + for i, kvPair := range env { + if strings.HasPrefix(kvPair, key+"=") { + // Key exists, update the value + env[i] = kv + return env, nil + } + } + + // Key is not set yet, append it + env = append(env, kv) + return env, nil +} diff --git a/internal/envutils/envutils_test.go b/internal/envutils/envutils_test.go new file mode 100644 index 0000000000..25c7108972 --- /dev/null +++ b/internal/envutils/envutils_test.go @@ -0,0 +1,241 @@ +package envutils_test + +import ( + "testing" + + "github.com/canonical/authd/internal/envutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetenv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + env []string + key string + want string + }{ + "Get_existing_environment_variable": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "FOO", + want: "bar", + }, + "Get_environment_variable_with_empty_value": { + env: []string{"FOO=bar", "EMPTY=", "BAZ=qux"}, + key: "EMPTY", + want: "", + }, + "Get_environment_variable_with_special_characters": { + env: []string{"PATH=/usr/bin:/usr/local/bin"}, + key: "PATH", + want: "/usr/bin:/usr/local/bin", + }, + "Get_environment_variable_with_spaces": { + env: []string{"MESSAGE=hello world"}, + key: "MESSAGE", + want: "hello world", + }, + "Get_environment_variable_with_equals_sign_in_value": { + env: []string{"EQUATION=x=y+z"}, + key: "EQUATION", + want: "x=y+z", + }, + "Get_first_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "FIRST", + want: "1", + }, + "Get_middle_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "SECOND", + want: "2", + }, + "Get_last_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "THIRD", + want: "3", + }, + "Return_empty_string_when_key_not_found": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "MISSING", + want: "", + }, + "Return_empty_string_when_key_not_found_in_empty_environment": { + env: []string{}, + key: "VAR", + want: "", + }, + "Return_empty_string_when_looking_for_partial_key_match": { + env: []string{"FOOBAR=baz"}, + key: "FOO", + want: "", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := envutils.Getenv(tc.env, tc.key) + assert.Equal(t, tc.want, got, "Value should match expected") + }) + } +} + +func TestSetenv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + env []string + key string + value string + want []string + wantErr bool + errContains string + }{ + "Set_new_environment_variable": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "NEW_VAR", + value: "new_value", + want: []string{"FOO=bar", "BAZ=qux", "NEW_VAR=new_value"}, + }, + "Update_existing_environment_variable": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "FOO", + value: "updated", + want: []string{"FOO=updated", "BAZ=qux"}, + }, + "Set_variable_in_empty_environment": { + env: []string{}, + key: "VAR", + value: "value", + want: []string{"VAR=value"}, + }, + "Set_variable_with_empty_value": { + env: []string{"FOO=bar"}, + key: "EMPTY", + value: "", + want: []string{"FOO=bar", "EMPTY="}, + }, + "Update_variable_to_empty_value": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "FOO", + value: "", + want: []string{"FOO=", "BAZ=qux"}, + }, + "Set_variable_with_special_characters_in_value": { + env: []string{}, + key: "PATH", + value: "/usr/bin:/usr/local/bin", + want: []string{"PATH=/usr/bin:/usr/local/bin"}, + }, + "Set_variable_with_spaces_in_value": { + env: []string{}, + key: "MESSAGE", + value: "hello world", + want: []string{"MESSAGE=hello world"}, + }, + "Update_first_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "FIRST", + value: "updated", + want: []string{"FIRST=updated", "SECOND=2", "THIRD=3"}, + }, + "Update_middle_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "SECOND", + value: "updated", + want: []string{"FIRST=1", "SECOND=updated", "THIRD=3"}, + }, + "Update_last_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "THIRD", + value: "updated", + want: []string{"FIRST=1", "SECOND=2", "THIRD=updated"}, + }, + + // Error cases + "Error_on_empty_key": { + env: []string{"FOO=bar"}, + key: "", + value: "value", + wantErr: true, + errContains: "empty key", + }, + "Error_on_key_with_equals_sign": { + env: []string{"FOO=bar"}, + key: "KEY=VALUE", + value: "value", + wantErr: true, + errContains: "invalid key", + }, + "Error_on_key_with_null_byte": { + env: []string{"FOO=bar"}, + key: "KEY\x00", + value: "value", + wantErr: true, + errContains: "invalid key", + }, + "Error_on_value_with_null_byte": { + env: []string{"FOO=bar"}, + key: "KEY", + value: "value\x00", + wantErr: true, + errContains: "invalid value", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := envutils.Setenv(tc.env, tc.key, tc.value) + + if tc.wantErr { + require.Error(t, err, "Setenv should return an error") + assert.Contains(t, err.Error(), tc.errContains, "Error message should contain expected text") + return + } + + require.NoError(t, err, "Setenv should not return an error") + assert.Equal(t, tc.want, got, "Environment slice should match expected") + }) + } +} + +func TestSetenvDoesNotModifyOriginal(t *testing.T) { + t.Parallel() + + original := []string{"FOO=bar", "BAZ=qux"} + originalCopy := make([]string, len(original)) + copy(originalCopy, original) + + result, err := envutils.Setenv(original, "NEW", "value") + require.NoError(t, err) + + // Verify original slice content is unchanged (but may have increased capacity) + assert.Equal(t, originalCopy, original[:len(originalCopy)], "Original slice content should not be modified") + // Verify result contains the new variable + assert.Contains(t, result, "NEW=value", "Result should contain new variable") +} + +func TestSetenvPreservesOrder(t *testing.T) { + t.Parallel() + + // Update a middle variable + env1 := []string{"A=1", "B=2", "C=3", "D=4", "E=5"} + result, err := envutils.Setenv(env1, "C", "updated") + require.NoError(t, err) + + expected := []string{"A=1", "B=2", "C=updated", "D=4", "E=5"} + assert.Equal(t, expected, result, "Order should be preserved when updating") + + // Add a new variable + env2 := []string{"A=1", "B=2", "C=3", "D=4", "E=5"} + result2, err := envutils.Setenv(env2, "F", "6") + require.NoError(t, err) + + expected2 := []string{"A=1", "B=2", "C=3", "D=4", "E=5", "F=6"} + assert.Equal(t, expected2, result2, "New variable should be appended at the end") +} From 4551e0883beca6232f25c720dccd7b8b65d52de9 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Mon, 23 Feb 2026 12:37:39 +0100 Subject: [PATCH 09/14] Fix authctl tests not producing coverage data We have to pass the GOCOVERDIR environment variable to the authctl binary. --- cmd/authctl/group/group_test.go | 1 + cmd/authctl/group/set-gid_test.go | 19 ++++++++++--------- cmd/authctl/main_test.go | 1 + cmd/authctl/user/lock_test.go | 9 +++++---- cmd/authctl/user/set-shell_test.go | 9 +++++---- cmd/authctl/user/set-uid_test.go | 20 ++++++++++---------- cmd/authctl/user/user_test.go | 1 + 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/cmd/authctl/group/group_test.go b/cmd/authctl/group/group_test.go index 34d3093521..8fb4d7559a 100644 --- a/cmd/authctl/group/group_test.go +++ b/cmd/authctl/group/group_test.go @@ -32,6 +32,7 @@ func TestGroupCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"group"}, tc.args...)...) + cmd.Env = []string{testutils.CoverDirEnv()} testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/group/set-gid_test.go b/cmd/authctl/group/set-gid_test.go index d26382f1df..aa9533d758 100644 --- a/cmd/authctl/group/set-gid_test.go +++ b/cmd/authctl/group/set-gid_test.go @@ -2,12 +2,12 @@ package group_test import ( "math" - "os" "os/exec" "path/filepath" "strconv" "testing" + "github.com/canonical/authd/internal/envutils" "github.com/canonical/authd/internal/testutils" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -26,8 +26,10 @@ func TestSetGIDCommand(t *testing.T) { testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -69,18 +71,17 @@ func TestSetGIDCommand(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // Copy authctlEnv to avoid modifying the original slice. + authctlEnv := append([]string{}, authctlEnv...) if tc.authdUnavailable { - origValue := os.Getenv("AUTHD_SOCKET") - err := os.Setenv("AUTHD_SOCKET", "/non-existent") + var err error + authctlEnv, err = envutils.Setenv(authctlEnv, "AUTHD_SOCKET", "/non-existent") require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") - t.Cleanup(func() { - err := os.Setenv("AUTHD_SOCKET", origValue) - require.NoError(t, err, "Failed to restore AUTHD_SOCKET environment variable") - }) } //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"group"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/main_test.go b/cmd/authctl/main_test.go index 52c2a0ee62..278286fac7 100644 --- a/cmd/authctl/main_test.go +++ b/cmd/authctl/main_test.go @@ -33,6 +33,7 @@ func TestRootCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, tc.args...) + cmd.Env = []string{testutils.CoverDirEnv()} testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/lock_test.go b/cmd/authctl/user/lock_test.go index eb89ac349a..0828bf9165 100644 --- a/cmd/authctl/user/lock_test.go +++ b/cmd/authctl/user/lock_test.go @@ -1,13 +1,11 @@ package user_test import ( - "os" "os/exec" "path/filepath" "testing" "github.com/canonical/authd/internal/testutils" - "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" ) @@ -20,8 +18,10 @@ func TestUserLockCommand(t *testing.T) { testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -38,6 +38,7 @@ func TestUserLockCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/set-shell_test.go b/cmd/authctl/user/set-shell_test.go index 0d0b81dcdf..d6110dd64a 100644 --- a/cmd/authctl/user/set-shell_test.go +++ b/cmd/authctl/user/set-shell_test.go @@ -1,13 +1,11 @@ package user_test import ( - "os" "os/exec" "path/filepath" "testing" "github.com/canonical/authd/internal/testutils" - "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" ) @@ -20,8 +18,10 @@ func TestSetShellCommand(t *testing.T) { testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -42,6 +42,7 @@ func TestSetShellCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/set-uid_test.go b/cmd/authctl/user/set-uid_test.go index e4a2ed0765..6284330d1f 100644 --- a/cmd/authctl/user/set-uid_test.go +++ b/cmd/authctl/user/set-uid_test.go @@ -2,12 +2,12 @@ package user_test import ( "math" - "os" "os/exec" "path/filepath" "strconv" "testing" + "github.com/canonical/authd/internal/envutils" "github.com/canonical/authd/internal/testutils" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -19,15 +19,16 @@ func TestSetUIDCommand(t *testing.T) { // which makes userslocking.WriteLock() return an error immediately when the lock // is already held - unlike the normal behavior which tries to acquire the lock // for 15 seconds before returning an error. - daemonSocket := testutils.StartAuthd(t, daemonPath, testutils.WithGroupFile(filepath.Join("testdata", "empty.group")), testutils.WithPreviousDBState("one_user_and_group"), testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -69,18 +70,17 @@ func TestSetUIDCommand(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // Copy authctlEnv to avoid modifying the original slice. + authctlEnv := append([]string{}, authctlEnv...) if tc.authdUnavailable { - origValue := os.Getenv("AUTHD_SOCKET") - err := os.Setenv("AUTHD_SOCKET", "/non-existent") + var err error + authctlEnv, err = envutils.Setenv(authctlEnv, "AUTHD_SOCKET", "/non-existent") require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") - t.Cleanup(func() { - err := os.Setenv("AUTHD_SOCKET", origValue) - require.NoError(t, err, "Failed to restore AUTHD_SOCKET environment variable") - }) } //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/user_test.go b/cmd/authctl/user/user_test.go index 5010e8cc74..1523bd2d67 100644 --- a/cmd/authctl/user/user_test.go +++ b/cmd/authctl/user/user_test.go @@ -32,6 +32,7 @@ func TestUserCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = []string{testutils.CoverDirEnv()} testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } From 9591b09808f9783b20d736e98e88d0decea2d384 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Mon, 23 Feb 2026 13:40:41 +0100 Subject: [PATCH 10/14] Improve error message The grpc.NewClient does not actually try to connect to authd, it just sets up the client but doesn't do any I/O. --- cmd/authctl/internal/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/authctl/internal/client/client.go b/cmd/authctl/internal/client/client.go index fff6efe7ff..1c388e72e0 100644 --- a/cmd/authctl/internal/client/client.go +++ b/cmd/authctl/internal/client/client.go @@ -27,7 +27,7 @@ func NewUserServiceClient() (authd.UserServiceClient, error) { conn, err := grpc.NewClient(authdSocket, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - return nil, fmt.Errorf("failed to connect to authd: %w", err) + return nil, fmt.Errorf("failed to create gRPC client: %w", err) } client := authd.NewUserServiceClient(conn) From 07f98a7110acf2a7533d9e55ab69f016082998b3 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 25 Feb 2026 13:56:22 +0100 Subject: [PATCH 11/14] SetShell: Only allow root to set a user's shell The original plan was to allow users to change their own shell without root permissions, like chsh does. However, chsh requires the user to authenticate before it allows change the shell. We should also be able to do that by letting the user authenticate to the broker. That's a lot more effort than originally planned though, so for now we only allow root to change user shells. --- internal/services/user/user.go | 12 ++---------- internal/users/manager.go | 6 ++++++ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/internal/services/user/user.go b/internal/services/user/user.go index b76e64dedf..6299a45794 100644 --- a/internal/services/user/user.go +++ b/internal/services/user/user.go @@ -275,19 +275,11 @@ func (s Service) SetShell(ctx context.Context, req *authd.SetShellRequest) (*aut // authd uses lowercase group names. name := strings.ToLower(req.GetName()) - user, err := s.userManager.UserByName(name) - if errors.Is(err, users.NoDataFoundError{}) { - return nil, status.Errorf(codes.NotFound, "user %q not found", name) - } - if err != nil { - return nil, grpcError(err) - } - - if err := s.permissionManager.CheckRequestIsFromRootOrUID(ctx, user.UID); err != nil { + if err := s.permissionManager.CheckRequestIsFromRoot(ctx); err != nil { return nil, status.Error(codes.PermissionDenied, err.Error()) } - if err = s.userManager.SetShell(name, req.GetShell()); err != nil { + if err := s.userManager.SetShell(name, req.GetShell()); err != nil { log.Errorf(ctx, "SetShell: %v", err) return nil, grpcError(err) } diff --git a/internal/users/manager.go b/internal/users/manager.go index 2ab4164b7a..a997beeaf5 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -671,6 +671,12 @@ func (m *Manager) SetShell(username, shell string) (err error) { return errors.New("empty username") } + // Check if the user exists + _, err = m.db.UserByName(username) + if err != nil { + return err + } + err = checkValidShell(shell) if err != nil { return err From 7941b7e802d2eecf8100bc0aaecb95b8e21f72cf Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 25 Feb 2026 22:51:30 +0100 Subject: [PATCH 12/14] SetShell: Allow root to set an invalid shell But display a warning, similar to the behavior of chsh. --- cmd/authctl/user/set-shell.go | 12 +- internal/proto/authd/authd.pb.go | 137 ++++++++++++------ internal/proto/authd/authd.proto | 6 +- internal/proto/authd/authd_grpc.pb.go | 10 +- internal/services/user/user.go | 9 +- internal/users/manager.go | 25 ++-- internal/users/manager_test.go | 45 +++--- .../db} | 0 .../Successfully_set_shell/warnings | 1 + .../Warning_if_shell_does_not_exist/db | 18 +++ .../Warning_if_shell_does_not_exist/warnings | 1 + .../Warning_if_shell_is_directory/db | 18 +++ .../Warning_if_shell_is_directory/warnings | 1 + .../Warning_if_shell_is_not_executable/db | 18 +++ .../warnings | 1 + .../Warning_if_shell_is_not_in_etc_shells/db | 18 +++ .../warnings | 1 + internal/users/userutils.go | 25 +++- 18 files changed, 255 insertions(+), 91 deletions(-) rename internal/users/testdata/golden/TestSetShell/{Successfully_set_shell => Successfully_set_shell/db} (100%) create mode 100644 internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db create mode 100644 internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings diff --git a/cmd/authctl/user/set-shell.go b/cmd/authctl/user/set-shell.go index 7b180b182b..0e8e3e8778 100644 --- a/cmd/authctl/user/set-shell.go +++ b/cmd/authctl/user/set-shell.go @@ -5,6 +5,7 @@ import ( "github.com/canonical/authd/cmd/authctl/internal/client" "github.com/canonical/authd/cmd/authctl/internal/completion" + "github.com/canonical/authd/cmd/authctl/internal/log" "github.com/canonical/authd/internal/proto/authd" "github.com/spf13/cobra" ) @@ -26,15 +27,20 @@ func runSetShell(cmd *cobra.Command, args []string) error { return err } - _, err = svc.SetShell(context.Background(), &authd.SetShellRequest{ + resp, err := svc.SetShell(context.Background(), &authd.SetShellRequest{ Name: name, Shell: shell, }) - if err != nil { + if resp == nil { return err } - return nil + // Print any warnings returned by the server. + for _, warning := range resp.Warnings { + log.Warning(warning) + } + + return err } func setShellCompletionFunc(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { diff --git a/internal/proto/authd/authd.pb.go b/internal/proto/authd/authd.pb.go index b0a71ba675..5351978afb 100644 --- a/internal/proto/authd/authd.pb.go +++ b/internal/proto/authd/authd.pb.go @@ -1465,6 +1465,50 @@ func (x *SetShellRequest) GetShell() string { return "" } +type SetShellResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Warnings []string `protobuf:"bytes,1,rep,name=warnings,proto3" json:"warnings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetShellResponse) Reset() { + *x = SetShellResponse{} + mi := &file_authd_proto_msgTypes[27] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetShellResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetShellResponse) ProtoMessage() {} + +func (x *SetShellResponse) ProtoReflect() protoreflect.Message { + mi := &file_authd_proto_msgTypes[27] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetShellResponse.ProtoReflect.Descriptor instead. +func (*SetShellResponse) Descriptor() ([]byte, []int) { + return file_authd_proto_rawDescGZIP(), []int{27} +} + +func (x *SetShellResponse) GetWarnings() []string { + if x != nil { + return x.Warnings + } + return nil +} + type User struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` @@ -1479,7 +1523,7 @@ type User struct { func (x *User) Reset() { *x = User{} - mi := &file_authd_proto_msgTypes[27] + mi := &file_authd_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1491,7 +1535,7 @@ func (x *User) String() string { func (*User) ProtoMessage() {} func (x *User) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[27] + mi := &file_authd_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1504,7 +1548,7 @@ func (x *User) ProtoReflect() protoreflect.Message { // Deprecated: Use User.ProtoReflect.Descriptor instead. func (*User) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{27} + return file_authd_proto_rawDescGZIP(), []int{28} } func (x *User) GetName() string { @@ -1558,7 +1602,7 @@ type Users struct { func (x *Users) Reset() { *x = Users{} - mi := &file_authd_proto_msgTypes[28] + mi := &file_authd_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1570,7 +1614,7 @@ func (x *Users) String() string { func (*Users) ProtoMessage() {} func (x *Users) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[28] + mi := &file_authd_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1583,7 +1627,7 @@ func (x *Users) ProtoReflect() protoreflect.Message { // Deprecated: Use Users.ProtoReflect.Descriptor instead. func (*Users) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{28} + return file_authd_proto_rawDescGZIP(), []int{29} } func (x *Users) GetUsers() []*User { @@ -1606,7 +1650,7 @@ type Group struct { func (x *Group) Reset() { *x = Group{} - mi := &file_authd_proto_msgTypes[29] + mi := &file_authd_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1618,7 +1662,7 @@ func (x *Group) String() string { func (*Group) ProtoMessage() {} func (x *Group) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[29] + mi := &file_authd_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1631,7 +1675,7 @@ func (x *Group) ProtoReflect() protoreflect.Message { // Deprecated: Use Group.ProtoReflect.Descriptor instead. func (*Group) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{29} + return file_authd_proto_rawDescGZIP(), []int{30} } func (x *Group) GetName() string { @@ -1671,7 +1715,7 @@ type Groups struct { func (x *Groups) Reset() { *x = Groups{} - mi := &file_authd_proto_msgTypes[30] + mi := &file_authd_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1683,7 +1727,7 @@ func (x *Groups) String() string { func (*Groups) ProtoMessage() {} func (x *Groups) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[30] + mi := &file_authd_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1696,7 +1740,7 @@ func (x *Groups) ProtoReflect() protoreflect.Message { // Deprecated: Use Groups.ProtoReflect.Descriptor instead. func (*Groups) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{30} + return file_authd_proto_rawDescGZIP(), []int{31} } func (x *Groups) GetGroups() []*Group { @@ -1717,7 +1761,7 @@ type ABResponse_BrokerInfo struct { func (x *ABResponse_BrokerInfo) Reset() { *x = ABResponse_BrokerInfo{} - mi := &file_authd_proto_msgTypes[31] + mi := &file_authd_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1729,7 +1773,7 @@ func (x *ABResponse_BrokerInfo) String() string { func (*ABResponse_BrokerInfo) ProtoMessage() {} func (x *ABResponse_BrokerInfo) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[31] + mi := &file_authd_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1776,7 +1820,7 @@ type GAMResponse_AuthenticationMode struct { func (x *GAMResponse_AuthenticationMode) Reset() { *x = GAMResponse_AuthenticationMode{} - mi := &file_authd_proto_msgTypes[32] + mi := &file_authd_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1788,7 +1832,7 @@ func (x *GAMResponse_AuthenticationMode) String() string { func (*GAMResponse_AuthenticationMode) ProtoMessage() {} func (x *GAMResponse_AuthenticationMode) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[32] + mi := &file_authd_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1833,7 +1877,7 @@ type IARequest_AuthenticationData struct { func (x *IARequest_AuthenticationData) Reset() { *x = IARequest_AuthenticationData{} - mi := &file_authd_proto_msgTypes[33] + mi := &file_authd_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1845,7 +1889,7 @@ func (x *IARequest_AuthenticationData) String() string { func (*IARequest_AuthenticationData) ProtoMessage() {} func (x *IARequest_AuthenticationData) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[33] + mi := &file_authd_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2053,7 +2097,9 @@ const file_authd_proto_rawDesc = "" + "\bwarnings\x18\x03 \x03(\tR\bwarnings\";\n" + "\x0fSetShellRequest\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" + - "\x05shell\x18\x02 \x01(\tR\x05shell\"\x84\x01\n" + + "\x05shell\x18\x02 \x01(\tR\x05shell\".\n" + + "\x10SetShellResponse\x12\x1a\n" + + "\bwarnings\x18\x01 \x03(\tR\bwarnings\"\x84\x01\n" + "\x04User\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n" + "\x03uid\x18\x02 \x01(\rR\x03uid\x12\x10\n" + @@ -2083,7 +2129,7 @@ const file_authd_proto_rawDesc = "" + "\x0fIsAuthenticated\x12\x10.authd.IARequest\x1a\x11.authd.IAResponse\x12,\n" + "\n" + "EndSession\x12\x10.authd.ESRequest\x1a\f.authd.Empty\x12<\n" + - "\x17SetDefaultBrokerForUser\x12\x13.authd.SDBFURequest\x1a\f.authd.Empty2\xe8\x04\n" + + "\x17SetDefaultBrokerForUser\x12\x13.authd.SDBFURequest\x1a\f.authd.Empty2\xf3\x04\n" + "\vUserService\x129\n" + "\rGetUserByName\x12\x1b.authd.GetUserByNameRequest\x1a\v.authd.User\x125\n" + "\vGetUserByID\x12\x19.authd.GetUserByIDRequest\x1a\v.authd.User\x12'\n" + @@ -2093,8 +2139,8 @@ const file_authd_proto_rawDesc = "" + "UnlockUser\x12\x18.authd.UnlockUserRequest\x1a\f.authd.Empty\x12>\n" + "\tSetUserID\x12\x17.authd.SetUserIDRequest\x1a\x18.authd.SetUserIDResponse\x12A\n" + "\n" + - "SetGroupID\x12\x18.authd.SetGroupIDRequest\x1a\x19.authd.SetGroupIDResponse\x120\n" + - "\bSetShell\x12\x16.authd.SetShellRequest\x1a\f.authd.Empty\x12<\n" + + "SetGroupID\x12\x18.authd.SetGroupIDRequest\x1a\x19.authd.SetGroupIDResponse\x12;\n" + + "\bSetShell\x12\x16.authd.SetShellRequest\x1a\x17.authd.SetShellResponse\x12<\n" + "\x0eGetGroupByName\x12\x1c.authd.GetGroupByNameRequest\x1a\f.authd.Group\x128\n" + "\fGetGroupByID\x12\x1a.authd.GetGroupByIDRequest\x1a\f.authd.Group\x12)\n" + "\n" + @@ -2113,7 +2159,7 @@ func file_authd_proto_rawDescGZIP() []byte { } var file_authd_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_authd_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_authd_proto_msgTypes = make([]protoimpl.MessageInfo, 35) var file_authd_proto_goTypes = []any{ (SessionMode)(0), // 0: authd.SessionMode (*Empty)(nil), // 1: authd.Empty @@ -2143,23 +2189,24 @@ var file_authd_proto_goTypes = []any{ (*SetGroupIDRequest)(nil), // 25: authd.SetGroupIDRequest (*SetGroupIDResponse)(nil), // 26: authd.SetGroupIDResponse (*SetShellRequest)(nil), // 27: authd.SetShellRequest - (*User)(nil), // 28: authd.User - (*Users)(nil), // 29: authd.Users - (*Group)(nil), // 30: authd.Group - (*Groups)(nil), // 31: authd.Groups - (*ABResponse_BrokerInfo)(nil), // 32: authd.ABResponse.BrokerInfo - (*GAMResponse_AuthenticationMode)(nil), // 33: authd.GAMResponse.AuthenticationMode - (*IARequest_AuthenticationData)(nil), // 34: authd.IARequest.AuthenticationData + (*SetShellResponse)(nil), // 28: authd.SetShellResponse + (*User)(nil), // 29: authd.User + (*Users)(nil), // 30: authd.Users + (*Group)(nil), // 31: authd.Group + (*Groups)(nil), // 32: authd.Groups + (*ABResponse_BrokerInfo)(nil), // 33: authd.ABResponse.BrokerInfo + (*GAMResponse_AuthenticationMode)(nil), // 34: authd.GAMResponse.AuthenticationMode + (*IARequest_AuthenticationData)(nil), // 35: authd.IARequest.AuthenticationData } var file_authd_proto_depIdxs = []int32{ - 32, // 0: authd.ABResponse.brokers_infos:type_name -> authd.ABResponse.BrokerInfo + 33, // 0: authd.ABResponse.brokers_infos:type_name -> authd.ABResponse.BrokerInfo 0, // 1: authd.SBRequest.mode:type_name -> authd.SessionMode 9, // 2: authd.GAMRequest.supported_ui_layouts:type_name -> authd.UILayout - 33, // 3: authd.GAMResponse.authentication_modes:type_name -> authd.GAMResponse.AuthenticationMode + 34, // 3: authd.GAMResponse.authentication_modes:type_name -> authd.GAMResponse.AuthenticationMode 9, // 4: authd.SAMResponse.ui_layout_info:type_name -> authd.UILayout - 34, // 5: authd.IARequest.authentication_data:type_name -> authd.IARequest.AuthenticationData - 28, // 6: authd.Users.users:type_name -> authd.User - 30, // 7: authd.Groups.groups:type_name -> authd.Group + 35, // 5: authd.IARequest.authentication_data:type_name -> authd.IARequest.AuthenticationData + 29, // 6: authd.Users.users:type_name -> authd.User + 31, // 7: authd.Groups.groups:type_name -> authd.Group 1, // 8: authd.PAM.AvailableBrokers:input_type -> authd.Empty 2, // 9: authd.PAM.GetPreviousBroker:input_type -> authd.GPBRequest 6, // 10: authd.PAM.SelectBroker:input_type -> authd.SBRequest @@ -2187,17 +2234,17 @@ var file_authd_proto_depIdxs = []int32{ 14, // 32: authd.PAM.IsAuthenticated:output_type -> authd.IAResponse 1, // 33: authd.PAM.EndSession:output_type -> authd.Empty 1, // 34: authd.PAM.SetDefaultBrokerForUser:output_type -> authd.Empty - 28, // 35: authd.UserService.GetUserByName:output_type -> authd.User - 28, // 36: authd.UserService.GetUserByID:output_type -> authd.User - 29, // 37: authd.UserService.ListUsers:output_type -> authd.Users + 29, // 35: authd.UserService.GetUserByName:output_type -> authd.User + 29, // 36: authd.UserService.GetUserByID:output_type -> authd.User + 30, // 37: authd.UserService.ListUsers:output_type -> authd.Users 1, // 38: authd.UserService.LockUser:output_type -> authd.Empty 1, // 39: authd.UserService.UnlockUser:output_type -> authd.Empty 24, // 40: authd.UserService.SetUserID:output_type -> authd.SetUserIDResponse 26, // 41: authd.UserService.SetGroupID:output_type -> authd.SetGroupIDResponse - 1, // 42: authd.UserService.SetShell:output_type -> authd.Empty - 30, // 43: authd.UserService.GetGroupByName:output_type -> authd.Group - 30, // 44: authd.UserService.GetGroupByID:output_type -> authd.Group - 31, // 45: authd.UserService.ListGroups:output_type -> authd.Groups + 28, // 42: authd.UserService.SetShell:output_type -> authd.SetShellResponse + 31, // 43: authd.UserService.GetGroupByName:output_type -> authd.Group + 31, // 44: authd.UserService.GetGroupByID:output_type -> authd.Group + 32, // 45: authd.UserService.ListGroups:output_type -> authd.Groups 27, // [27:46] is the sub-list for method output_type 8, // [8:27] is the sub-list for method input_type 8, // [8:8] is the sub-list for extension type_name @@ -2211,8 +2258,8 @@ func file_authd_proto_init() { return } file_authd_proto_msgTypes[8].OneofWrappers = []any{} - file_authd_proto_msgTypes[31].OneofWrappers = []any{} - file_authd_proto_msgTypes[33].OneofWrappers = []any{ + file_authd_proto_msgTypes[32].OneofWrappers = []any{} + file_authd_proto_msgTypes[34].OneofWrappers = []any{ (*IARequest_AuthenticationData_Secret)(nil), (*IARequest_AuthenticationData_Wait)(nil), (*IARequest_AuthenticationData_Skip)(nil), @@ -2224,7 +2271,7 @@ func file_authd_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_authd_proto_rawDesc), len(file_authd_proto_rawDesc)), NumEnums: 1, - NumMessages: 34, + NumMessages: 35, NumExtensions: 0, NumServices: 2, }, diff --git a/internal/proto/authd/authd.proto b/internal/proto/authd/authd.proto index 5706107bf1..9aae541f5f 100644 --- a/internal/proto/authd/authd.proto +++ b/internal/proto/authd/authd.proto @@ -137,7 +137,7 @@ service UserService { rpc UnlockUser(UnlockUserRequest) returns (Empty); rpc SetUserID(SetUserIDRequest) returns (SetUserIDResponse); rpc SetGroupID(SetGroupIDRequest) returns (SetGroupIDResponse); - rpc SetShell(SetShellRequest) returns (Empty); + rpc SetShell(SetShellRequest) returns (SetShellResponse); rpc GetGroupByName(GetGroupByNameRequest) returns (Group); rpc GetGroupByID(GetGroupByIDRequest) returns (Group); @@ -202,6 +202,10 @@ message SetShellRequest { string shell = 2; } +message SetShellResponse { + repeated string warnings = 1; +} + message User { string name = 1; uint32 uid = 2; diff --git a/internal/proto/authd/authd_grpc.pb.go b/internal/proto/authd/authd_grpc.pb.go index b6ca785927..d6146889df 100644 --- a/internal/proto/authd/authd_grpc.pb.go +++ b/internal/proto/authd/authd_grpc.pb.go @@ -411,7 +411,7 @@ type UserServiceClient interface { UnlockUser(ctx context.Context, in *UnlockUserRequest, opts ...grpc.CallOption) (*Empty, error) SetUserID(ctx context.Context, in *SetUserIDRequest, opts ...grpc.CallOption) (*SetUserIDResponse, error) SetGroupID(ctx context.Context, in *SetGroupIDRequest, opts ...grpc.CallOption) (*SetGroupIDResponse, error) - SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*Empty, error) + SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*SetShellResponse, error) GetGroupByName(ctx context.Context, in *GetGroupByNameRequest, opts ...grpc.CallOption) (*Group, error) GetGroupByID(ctx context.Context, in *GetGroupByIDRequest, opts ...grpc.CallOption) (*Group, error) ListGroups(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Groups, error) @@ -495,9 +495,9 @@ func (c *userServiceClient) SetGroupID(ctx context.Context, in *SetGroupIDReques return out, nil } -func (c *userServiceClient) SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*Empty, error) { +func (c *userServiceClient) SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*SetShellResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(Empty) + out := new(SetShellResponse) err := c.cc.Invoke(ctx, UserService_SetShell_FullMethodName, in, out, cOpts...) if err != nil { return nil, err @@ -546,7 +546,7 @@ type UserServiceServer interface { UnlockUser(context.Context, *UnlockUserRequest) (*Empty, error) SetUserID(context.Context, *SetUserIDRequest) (*SetUserIDResponse, error) SetGroupID(context.Context, *SetGroupIDRequest) (*SetGroupIDResponse, error) - SetShell(context.Context, *SetShellRequest) (*Empty, error) + SetShell(context.Context, *SetShellRequest) (*SetShellResponse, error) GetGroupByName(context.Context, *GetGroupByNameRequest) (*Group, error) GetGroupByID(context.Context, *GetGroupByIDRequest) (*Group, error) ListGroups(context.Context, *Empty) (*Groups, error) @@ -581,7 +581,7 @@ func (UnimplementedUserServiceServer) SetUserID(context.Context, *SetUserIDReque func (UnimplementedUserServiceServer) SetGroupID(context.Context, *SetGroupIDRequest) (*SetGroupIDResponse, error) { return nil, status.Error(codes.Unimplemented, "method SetGroupID not implemented") } -func (UnimplementedUserServiceServer) SetShell(context.Context, *SetShellRequest) (*Empty, error) { +func (UnimplementedUserServiceServer) SetShell(context.Context, *SetShellRequest) (*SetShellResponse, error) { return nil, status.Error(codes.Unimplemented, "method SetShell not implemented") } func (UnimplementedUserServiceServer) GetGroupByName(context.Context, *GetGroupByNameRequest) (*Group, error) { diff --git a/internal/services/user/user.go b/internal/services/user/user.go index 6299a45794..4fe4d1fe3f 100644 --- a/internal/services/user/user.go +++ b/internal/services/user/user.go @@ -271,7 +271,7 @@ func (s Service) SetGroupID(ctx context.Context, req *authd.SetGroupIDRequest) ( } // SetShell sets the shell of a user. -func (s Service) SetShell(ctx context.Context, req *authd.SetShellRequest) (*authd.Empty, error) { +func (s Service) SetShell(ctx context.Context, req *authd.SetShellRequest) (*authd.SetShellResponse, error) { // authd uses lowercase group names. name := strings.ToLower(req.GetName()) @@ -279,12 +279,15 @@ func (s Service) SetShell(ctx context.Context, req *authd.SetShellRequest) (*aut return nil, status.Error(codes.PermissionDenied, err.Error()) } - if err := s.userManager.SetShell(name, req.GetShell()); err != nil { + warnings, err := s.userManager.SetShell(name, req.GetShell()) + if err != nil { log.Errorf(ctx, "SetShell: %v", err) return nil, grpcError(err) } - return &authd.Empty{}, nil + return &authd.SetShellResponse{ + Warnings: warnings, + }, nil } // userToProtobuf converts a types.UserEntry to authd.User. diff --git a/internal/users/manager.go b/internal/users/manager.go index a997beeaf5..85d3214bd4 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -666,36 +666,33 @@ func checkHomeDirOwner(home string, uid, gid uint32) error { } // SetShell sets the shell for the given user. -func (m *Manager) SetShell(username, shell string) (err error) { +func (m *Manager) SetShell(username, shell string) (warnings []string, err error) { if username == "" { - return errors.New("empty username") + return nil, errors.New("empty username") } // Check if the user exists _, err = m.db.UserByName(username) if err != nil { - return err + return nil, err } - err = checkValidShell(shell) + err = checkValidShellPath(shell) if err != nil { - return err - } - - stat, err := os.Stat(shell) - if errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("shell %q does not exist", shell) + return nil, err } - if stat.IsDir() || stat.Mode()&0111 == 0 { - return fmt.Errorf("shell %q is not executable", shell) + err = checkValidShell(shell) + if err != nil { + // We allow root to set an invalid shell but print a warning + warnings = append(warnings, fmt.Sprintf("Warning: %s", err.Error())) } if err = m.db.SetShell(username, shell); err != nil { - return err + return warnings, err } - return nil + return warnings, nil } // BrokerForUser returns the broker ID for the given user. diff --git a/internal/users/manager_test.go b/internal/users/manager_test.go index 8d113f6432..379e2ac2ec 100644 --- a/internal/users/manager_test.go +++ b/internal/users/manager_test.go @@ -1310,10 +1310,28 @@ func TestSetShell(t *testing.T) { emptyShell bool shell string - wantErr bool + wantWarnings int + wantErr bool }{ "Successfully_set_shell": {}, + "Warning_if_shell_is_not_in_etc_shells": { + shell: "/bin/ls", + wantWarnings: 1, + }, + "Warning_if_shell_does_not_exist": { + shell: "/doesnotexist", + wantWarnings: 1, + }, + "Warning_if_shell_is_directory": { + shell: "/etc", + wantWarnings: 1, + }, + "Warning_if_shell_is_not_executable": { + shell: "/etc/passwd", + wantWarnings: 1, + }, + // checkValidPasswdField error cases "Error_if_shell_is_empty": { emptyShell: true, @@ -1344,7 +1362,7 @@ func TestSetShell(t *testing.T) { wantErr: true, }, - // checkValidShell error cases + // checkValidShellPath error cases "Error_if_shell_is_not_absolute_path": { shell: "bin/sh", wantErr: true, @@ -1362,19 +1380,7 @@ func TestSetShell(t *testing.T) { wantErr: true, }, - // Other error cases - "Error_if_shell_does_not_exist": { - shell: "/doesnotexist", - wantErr: true, - }, - "Error_if_shell_is_directory": { - shell: "/etc", - wantErr: true, - }, - "Error_if_shell_is_not_executable": { - shell: "/etc/passwd", - wantErr: true, - }, + // other error cases "Error_if_user_does_not_exist": { nonExistentUser: true, wantErr: true, @@ -1409,15 +1415,20 @@ func TestSetShell(t *testing.T) { shell = tc.shell } - err = m.SetShell(username, shell) + warnings, err := m.SetShell(username, shell) requireErrorAssertions(t, err, nil, tc.wantErr) + + require.Len(t, warnings, tc.wantWarnings, "Number of warnings mismatch") + if tc.wantErr { return } yamlData, err := db.Z_ForTests_DumpNormalizedYAML(m.DB()) require.NoError(t, err) - golden.CheckOrUpdate(t, yamlData) + golden.CheckOrUpdate(t, yamlData, golden.WithPath("db")) + + golden.CheckOrUpdateYAML(t, warnings, golden.WithPath("warnings")) }) } } diff --git a/internal/users/testdata/golden/TestSetShell/Successfully_set_shell b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/db similarity index 100% rename from internal/users/testdata/golden/TestSetShell/Successfully_set_shell rename to internal/users/testdata/golden/TestSetShell/Successfully_set_shell/db diff --git a/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings new file mode 100644 index 0000000000..fe51488c70 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings @@ -0,0 +1 @@ +[] diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db new file mode 100644 index 0000000000..8bea3e2319 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /doesnotexist + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings new file mode 100644 index 0000000000..d8f388a81a --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/doesnotexist'' does not exist' diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db new file mode 100644 index 0000000000..4b8eee1498 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /etc + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings new file mode 100644 index 0000000000..8063733301 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/etc'' is not an executable file' diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db new file mode 100644 index 0000000000..4e506fb35a --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /etc/passwd + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings new file mode 100644 index 0000000000..0baecc86c8 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/etc/passwd'' is not an executable file' diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db new file mode 100644 index 0000000000..f5dd819874 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /bin/ls + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings new file mode 100644 index 0000000000..d3bff1320b --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/bin/ls'' is not allowed in /etc/shells' diff --git a/internal/users/userutils.go b/internal/users/userutils.go index db254c2600..554187e0d9 100644 --- a/internal/users/userutils.go +++ b/internal/users/userutils.go @@ -2,6 +2,7 @@ package users import ( "errors" + "fmt" "os" "path" "strings" @@ -32,30 +33,48 @@ func checkValidPasswdField(value string) (err error) { return nil } -func checkValidShell(shell string) (err error) { +func checkValidShellPath(shell string) (err error) { // Do the same checks as systemd-homed in shell_is_ok: // https://github.com/systemd/systemd/blob/ba67af7efb7b743ba1974ef9ceb53fba0e3f9e21/src/home/homectl.c#L2812 if err = checkValidPasswdField(shell); err != nil { return err } + if !path.IsAbs(shell) { return errors.New("shell must be an absolute path") } + if shell != path.Clean(shell) { return errors.New("shell path must be normalized") } + // PATH_MAX is counted with the terminating null byte if unix.PathMax-1 < len(shell) { return errors.New("shell path is too long") } + return nil +} + +func checkValidShell(shell string) (err error) { + // Check if the shell exists and is executable + stat, err := os.Stat(shell) + if errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("shell '%s' does not exist", shell) + } + + if stat.IsDir() || stat.Mode()&0111 == 0 { + return fmt.Errorf("shell '%s' is not an executable file", shell) + } + // Check if the shell is in the list of allowed shells in /etc/shells shells, err := os.ReadFile("/etc/shells") if err != nil { return err } + for _, allowedShell := range strings.Split(string(shells), "\n") { - if allowedShell[0] == '#' { + if len(allowedShell) > 0 && allowedShell[0] == '#' { // Skip comments continue } @@ -64,5 +83,5 @@ func checkValidShell(shell string) (err error) { } } - return errors.New("shell is not allowed in /etc/shells") + return fmt.Errorf("shell '%s' is not allowed in /etc/shells", shell) } From 5c758e9f64f7fee3181a1389bed76d64f57b8da7 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Wed, 25 Feb 2026 23:01:50 +0100 Subject: [PATCH 13/14] SetUserID: Improve warning messages --- internal/users/manager.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/users/manager.go b/internal/users/manager.go index 85d3214bd4..340da97e20 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -439,7 +439,7 @@ func (m *Manager) SetUserID(name string, uid uint32) (resp *SetUserIDResp, err e // Check if the home directory is currently owned by the user. homeUID, _, err := getHomeDirOwner(oldUser.Dir) if err != nil && !errors.Is(err, os.ErrNotExist) { - warning := fmt.Sprintf("Could not get owner of home directory '%s'.", oldUser.Dir) + warning := fmt.Sprintf("Warning: Could not get owner of home directory '%s', not updating ownership.", oldUser.Dir) log.Warningf(context.Background(), "%s: %v", warning, err) resp.Warnings = append(resp.Warnings, warning) return resp, nil @@ -451,7 +451,7 @@ func (m *Manager) SetUserID(name string, uid uint32) (resp *SetUserIDResp, err e } if homeUID != oldUser.UID { - warning := fmt.Sprintf("Not updating ownership of home directory '%s' because it is not owned by UID %d (current owner: %d).", oldUser.Dir, oldUser.UID, homeUID) + warning := fmt.Sprintf("Warning: Not updating ownership of home directory '%s' because it is not owned by UID %d (current owner: %d).", oldUser.Dir, oldUser.UID, homeUID) log.Warning(context.Background(), warning) resp.Warnings = append(resp.Warnings, warning) return resp, nil @@ -549,18 +549,18 @@ func (m *Manager) updateUserHomeDirOwnership(userRow db.UserRow, oldGID uint32, // Check if the home directory is currently owned by the group _, homeGID, err := getHomeDirOwner(userRow.Dir) if err != nil && !errors.Is(err, os.ErrNotExist) { - warning := fmt.Sprintf("Could not get owner of home directory '%s' for user '%s'.", userRow.Dir, userRow.Name) + warning := fmt.Sprintf("Warning: Could not get owner of home directory '%s', not updating ownership.", userRow.Dir) log.Warningf(context.Background(), "%s: %v", warning, err) return false, warning, nil } if errors.Is(err, os.ErrNotExist) { // The home directory does not exist, so we don't need to change the owner. - log.Debugf(context.Background(), "Home directory %q for user %q does not exist, skipping ownership change", userRow.Dir, userRow.Name) + log.Debugf(context.Background(), "Not updating ownership of home directory %q for user %q because it does not exist", userRow.Dir, userRow.Name) return false, "", nil } if homeGID != oldGID { - warning := fmt.Sprintf("Not updating ownership of home directory '%s' because it is not owned by GID %d (current owner: %d).", userRow.Dir, oldGID, homeGID) + warning := fmt.Sprintf("Warning: Not updating ownership of home directory '%s' because it is not owned by GID %d (current owner: %d).", userRow.Dir, oldGID, homeGID) log.Warning(context.Background(), warning) return false, warning, nil } From ee009434f9d9eaea27f258d61aafb9bc8aaf90f3 Mon Sep 17 00:00:00 2001 From: Adrian Dombeck Date: Thu, 26 Feb 2026 12:48:20 +0100 Subject: [PATCH 14/14] TestSetUserID: Fix entire warning message replaced The warning message in the golden file still didn't have the "Warning:" prefix we now add to most of our warnings. That's because the tests replaced the entire warning message. Fix by only replacing the parts which need to be replaced. --- internal/users/manager_bwrap_test.go | 11 +++++------ .../response | 2 +- .../response | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/users/manager_bwrap_test.go b/internal/users/manager_bwrap_test.go index c2a4cd3bd8..b05f1ace96 100644 --- a/internal/users/manager_bwrap_test.go +++ b/internal/users/manager_bwrap_test.go @@ -155,12 +155,11 @@ func TestSetUserID(t *testing.T) { // To make the tests deterministic, we replace the temporary home directory path with a placeholder for i, w := range resp.Warnings { - if regexp.MustCompile(`Could not get owner of home directory '([^"]+)'`).MatchString(w) { - resp.Warnings[i] = `Could not get owner of home directory '{{HOME}}'` - } - if regexp.MustCompile(`Not updating ownership of home directory '([^"]+)' because it is not owned by UID \d+ \(current owner: \d+\)`).MatchString(w) { - resp.Warnings[i] = `Not updating ownership of home directory '{{HOME}}' because it is not owned by UID {{UID}} (current owner: {{CURR_UID}})` - } + // Replace home directory path with placeholder + w = regexp.MustCompile(`home directory '([^']+)'`).ReplaceAllString(w, `home directory '{{HOME}}'`) + // Replace UID and current owner UID with placeholders + w = regexp.MustCompile(`UID (\d+) \(current owner: (\d+)\)`).ReplaceAllString(w, `UID {{UID}} (current owner: {{CURR_UID}})`) + resp.Warnings[i] = w } golden.CheckOrUpdateYAML(t, resp, golden.WithPath("response")) diff --git a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response index 5967e5be65..6b8aa56782 100644 --- a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response +++ b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response @@ -1,4 +1,4 @@ idchanged: true homedirownerchanged: false warnings: - - Could not get owner of home directory '{{HOME}}' + - 'Warning: Could not get owner of home directory ''{{HOME}}'', not updating ownership.' diff --git a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response index 1927dd6513..0f7edbb746 100644 --- a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response +++ b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response @@ -1,4 +1,4 @@ idchanged: true homedirownerchanged: false warnings: - - 'Not updating ownership of home directory ''{{HOME}}'' because it is not owned by UID {{UID}} (current owner: {{CURR_UID}})' + - 'Warning: Not updating ownership of home directory ''{{HOME}}'' because it is not owned by UID {{UID}} (current owner: {{CURR_UID}}).'