diff --git a/x/poa/keeper/keeper.go b/x/poa/keeper/keeper.go index bde50eb..57dd1f1 100644 --- a/x/poa/keeper/keeper.go +++ b/x/poa/keeper/keeper.go @@ -83,6 +83,17 @@ func (k Keeper) ExecuteAddValidator(ctx sdk.Context, msg *types.MsgAddValidator) if err != nil { return err } + + // Check if the maximum number of validators has been reached + validators, err := k.sk.GetAllValidators(ctx) + if err != nil { + return err + } + //nolint:gosec + if uint32(len(validators)) >= params.MaxValidators { + return types.ErrMaxValidatorsReached + } + denom := params.BondDenom balance := k.bk.GetBalance(ctx, accAddress, denom) if !balance.IsZero() { diff --git a/x/poa/keeper/keeper_test.go b/x/poa/keeper/keeper_test.go index 63b9eb6..0e69c5e 100644 --- a/x/poa/keeper/keeper_test.go +++ b/x/poa/keeper/keeper_test.go @@ -20,7 +20,8 @@ func poaKeeperTestSetup(t *testing.T) (*Keeper, sdk.Context) { stakingHooks.EXPECT().BeforeValidatorSlashed(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() stakingKeeper.EXPECT().GetParams(ctx).Return(stakingtypes.Params{ - BondDenom: "XRP", + BondDenom: "XRP", + MaxValidators: 32, }, nil).AnyTimes() stakingKeeper.EXPECT().GetValidator(ctx, gomock.Any()).Return(stakingtypes.Validator{Tokens: math.NewInt(0)}, nil).AnyTimes() stakingKeeper.EXPECT().GetAllDelegatorDelegations(ctx, gomock.Any()).Return([]stakingtypes.Delegation{}, nil).AnyTimes() @@ -32,6 +33,7 @@ func poaKeeperTestSetup(t *testing.T) (*Keeper, sdk.Context) { stakingKeeper.EXPECT().BondDenom(ctx).Return("XRP", nil).AnyTimes() stakingKeeper.EXPECT().Unbond(ctx, gomock.Any(), gomock.Any(), gomock.Any()).Return(math.ZeroInt(), nil).AnyTimes() stakingKeeper.EXPECT().Hooks().Return(stakingHooks).AnyTimes() + stakingKeeper.EXPECT().GetAllValidators(ctx).Return([]stakingtypes.Validator{}, nil).AnyTimes() } bankExpectations := func(ctx sdk.Context, bankKeeper *testutil.MockBankKeeper) { diff --git a/x/poa/types/errors.go b/x/poa/types/errors.go index a51b00e..17aefe8 100644 --- a/x/poa/types/errors.go +++ b/x/poa/types/errors.go @@ -15,4 +15,5 @@ var ( ErrAddressHasUnbondedTokens = sdkerrors.Register(ModuleName, 6, "address already has unbonded tokens") ErrAddressHasDelegatedTokens = sdkerrors.Register(ModuleName, 7, "address already has delegated tokens") ErrInvalidValidatorStatus = sdkerrors.Register(ModuleName, 8, "invalid validator status") + ErrMaxValidatorsReached = sdkerrors.Register(ModuleName, 9, "maximum number of validators reached") )