Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -182,19 +180,20 @@ public String toString() {

@Override
public void run() {
CompletableFuture<Void> applyLogFutures = CompletableFuture.completedFuture(null);
for(; state != State.STOP; ) {
try {
waitForCommit();
waitForCommit(applyLogFutures);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to understand changes and impact,

  1. waitForCoimmit(applyLogFutures) --> call takeSnapshot(applyLogFutures) --> future.get() ==> No Ops
  2. applyLog(applyLogFutures): can add some future object to be wait
  3. checkAndTakeSnapshot(applyLogFutures) : This is moved before STOP(), and does future.get() by internal takeSnapshot()
  4. In Stop() check, just do future.get() for waiting task completion which was not there earlier

Q1: Step 3 changes checkAndTakeSnapshot() moving out of stop() and force checking to take snapshot do have performance impact?
Q2: Let No snapshot to be taken, in this case, future.get() is never called, do this is intended ?

  • wait on future is called if need take snapshot in checkAndTakeSnapshot()
  • wait on future is called if stop()
  • else its not called

Q3: We can refactor code as above behavior, that,

  • applyLog() can return existing future set
  • check for future.get() before stop()
  • move checkAndTakeSnapshot() out before stop()

This have similar impact instead of passing future deep inside takeSnapshot().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


Future.get() will be called inconsequential of whether the snapshot is taken or not.

Copy link
Contributor Author

@swamirishi swamirishi Feb 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to do a future.get() just before taking a snapshot. We cannot be having transactions still being applied in the background which could cause some inconsistency. It makes more sense to wait just before taking a snapshot that is why I moved the future.get() inside takeSnapshot method


if (state == State.RELOAD) {
reload();
}

final MemoizedSupplier<List<CompletableFuture<Message>>> futures = applyLog();
checkAndTakeSnapshot(futures);
applyLogFutures = applyLog(applyLogFutures);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's check if applyLogFutures.isCompletedExceptionally() before applying logs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But in this case we have already submitted next set of tasks right? What would be the point of checking the exception?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is to fail fast. Otherwise, it will keep applying log even if the state machine has failed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The statemachine can take a call right if it has failed. From what I understand statemachine should receive all the transactions, and it is on the statemachine to have the guardrail if it should apply a transaction or not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to fail fast we should fail the group remove as well.

checkAndTakeSnapshot(applyLogFutures);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In one of case where no snpashot to be taken and no stop(), future.get() will not be called. Do this will fix the issue as expected for applyTransaction to finish ? I think no, it should wait for all case.

And related to performance impact if always need wait, do this call is always in stop() flow or all flow ? this part is not clear.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed, this is not a problem.


if (shouldStop()) {
checkAndTakeSnapshot(futures);
applyLogFutures.get();
stop();
}
} catch (Throwable t) {
Expand All @@ -210,14 +209,14 @@ public void run() {
}
}

private void waitForCommit() throws InterruptedException {
private void waitForCommit(CompletableFuture<?> applyLogFutures) throws InterruptedException, ExecutionException {
// When a peer starts, the committed is initialized to 0.
// It will be updated only after the leader contacts other peers.
// Thus it is possible to have applied > committed initially.
final long applied = getLastAppliedIndex();
for(; applied >= raftLog.getLastCommittedIndex() && state == State.RUNNING && !shouldStop(); ) {
if (server.getSnapshotRequestHandler().shouldTriggerTakingSnapshot()) {
takeSnapshot();
takeSnapshot(applyLogFutures);
}
if (awaitForSignal.await(100, TimeUnit.MILLISECONDS)) {
return;
Expand All @@ -239,8 +238,7 @@ private void reload() throws IOException {
state = State.RUNNING;
}

private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws RaftLogIOException {
final MemoizedSupplier<List<CompletableFuture<Message>>> futures = MemoizedSupplier.valueOf(ArrayList::new);
private CompletableFuture<Void> applyLog(CompletableFuture<Void> applyLogFutures) throws RaftLogIOException {
final long committed = raftLog.getLastCommittedIndex();
for(long applied; (applied = getLastAppliedIndex()) < committed && state == State.RUNNING && !shouldStop(); ) {
final long nextIndex = applied + 1;
Expand All @@ -263,7 +261,12 @@ private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws Raf
final long incremented = appliedIndex.incrementAndGet(debugIndexChange);
Preconditions.assertTrue(incremented == nextIndex);
if (f != null) {
futures.get().add(f);
CompletableFuture<Message> exceptionHandledFuture = f.exceptionally(ex -> {
LOG.error("Exception while {}: applying txn index={}, nextLog={}", this, nextIndex,
LogProtoUtils.toLogEntryString(entry), ex);
return null;
});
applyLogFutures = applyLogFutures.thenCombine(exceptionHandledFuture, (v, message) -> null);
f.thenAccept(m -> notifyAppliedIndex(incremented));
} else {
notifyAppliedIndex(incremented);
Expand All @@ -272,23 +275,20 @@ private MemoizedSupplier<List<CompletableFuture<Message>>> applyLog() throws Raf
next.release();
}
}
return futures;
return applyLogFutures;
}

private void checkAndTakeSnapshot(MemoizedSupplier<List<CompletableFuture<Message>>> futures)
private void checkAndTakeSnapshot(CompletableFuture<?> futures)
throws ExecutionException, InterruptedException {
// check if need to trigger a snapshot
if (shouldTakeSnapshot()) {
if (futures.isInitialized()) {
JavaUtils.allOf(futures.get()).get();
}

takeSnapshot();
takeSnapshot(futures);
}
}

private void takeSnapshot() {
private void takeSnapshot(CompletableFuture<?> applyLogFutures) throws ExecutionException, InterruptedException {
final long i;
applyLogFutures.get();
try {
try(UncheckedAutoCloseable ignored = Timekeeper.start(stateMachineMetrics.get().getTakeSnapshotTimer())) {
i = stateMachine.takeSnapshot();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,47 +28,106 @@
import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing;
import org.apache.ratis.statemachine.StateMachine;
import org.apache.ratis.statemachine.TransactionContext;
import org.junit.Assert;
import org.junit.Test;

import java.util.concurrent.CompletableFuture;
import org.junit.*;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

public abstract class StateMachineShutdownTests<CLUSTER extends MiniRaftCluster>
extends BaseTest
implements MiniRaftCluster.Factory.Get<CLUSTER> {

public static Logger LOG = LoggerFactory.getLogger(StateMachineUpdater.class);
private static MockedStatic<CompletableFuture> mocked;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually unused. Now, it is causing problems in JDK17; see RATIS-2344.

My bad that I did not see this earlier.

protected static class StateMachineWithConditionalWait extends
SimpleStateMachine4Testing {
boolean unblockAllTxns = false;
final Set<Long> blockTxns = ConcurrentHashMap.newKeySet();
private final ExecutorService executor = Executors.newFixedThreadPool(10);
public static Map<Long, Set<CompletableFuture<Message>>> futures = new ConcurrentHashMap<>();
public static Map<RaftPeerId, AtomicLong> numTxns = new ConcurrentHashMap<>();
private final Map<Long, Long> appliedTxns = new ConcurrentHashMap<>();

private synchronized void updateTxns() {
long appliedIndex = this.getLastAppliedTermIndex().getIndex() + 1;
Long appliedTerm = null;
while (appliedTxns.containsKey(appliedIndex)) {
appliedTerm = appliedTxns.remove(appliedIndex);
appliedIndex += 1;
}
if (appliedTerm != null) {
updateLastAppliedTermIndex(appliedTerm, appliedIndex - 1);
}
}

private final Long objectToWait = 0L;
volatile boolean blockOnApply = true;
@Override
public void notifyTermIndexUpdated(long term, long index) {
appliedTxns.put(index, term);
updateTxns();
}

@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
if (blockOnApply) {
synchronized (objectToWait) {
try {
objectToWait.wait();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException();
final RaftProtos.LogEntryProto entry = trx.getLogEntryUnsafe();

CompletableFuture<Message> future = new CompletableFuture<>();
futures.computeIfAbsent(Thread.currentThread().getId(), k -> new HashSet<>()).add(future);
executor.submit(() -> {
synchronized (blockTxns) {
if (!unblockAllTxns) {
blockTxns.add(entry.getIndex());
}
while (!unblockAllTxns && blockTxns.contains(entry.getIndex())) {
try {
blockTxns.wait(10000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
numTxns.computeIfAbsent(getId(), (k) -> new AtomicLong()).incrementAndGet();
appliedTxns.put(entry.getIndex(), entry.getTerm());
updateTxns();
future.complete(new RaftTestUtil.SimpleMessage("done"));
});
return future;
}

public void unBlockApplyTxn(long txnId) {
synchronized (blockTxns) {
blockTxns.remove(txnId);
blockTxns.notifyAll();
}
final RaftProtos.LogEntryProto entry = trx.getLogEntryUnsafe();
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());
return CompletableFuture.completedFuture(new RaftTestUtil.SimpleMessage("done"));
}

public void unBlockApplyTxn() {
blockOnApply = false;
synchronized (objectToWait) {
objectToWait.notifyAll();
public void unblockAllTxns() {
unblockAllTxns = true;
synchronized (blockTxns) {
for (Long txnId : blockTxns) {
blockTxns.remove(txnId);
}
blockTxns.notifyAll();
}
}
}

@Before
public void setup() {
mocked = Mockito.mockStatic(CompletableFuture.class, Mockito.CALLS_REAL_METHODS);
}

@After
public void tearDownClass() {
if (mocked != null) {
mocked.close();
}

}

@Test
public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
final RaftProperties prop = getProperties();
Expand All @@ -82,10 +141,9 @@ public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {

//Unblock leader and one follower
((StateMachineWithConditionalWait)leader.getStateMachine())
.unBlockApplyTxn();
.unblockAllTxns();
((StateMachineWithConditionalWait)cluster.
getFollowers().get(0).getStateMachine()).unBlockApplyTxn();

getFollowers().get(0).getStateMachine()).unblockAllTxns();
cluster.getLeaderAndSendFirstMessage(true);

try (final RaftClient client = cluster.createClient(leaderId)) {
Expand All @@ -107,16 +165,30 @@ public void testStateMachineShutdownWaitsForApplyTxn() throws Exception {
final Thread t = new Thread(secondFollower::close);
t.start();

// The second follower should still be blocked in apply transaction
Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < logIndex);


// Now unblock the second follower
((StateMachineWithConditionalWait) secondFollower.getStateMachine())
.unBlockApplyTxn();
long minIndex = ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).blockTxns.stream()
.min(Comparator.naturalOrder()).get();
Assert.assertEquals(2, StateMachineWithConditionalWait.numTxns.values().stream()
.filter(val -> val.get() == 3).count());
// The second follower should still be blocked in apply transaction
Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex);
for (long index : ((StateMachineWithConditionalWait) secondFollower.getStateMachine()).blockTxns) {
if (minIndex != index) {
((StateMachineWithConditionalWait) secondFollower.getStateMachine()).unBlockApplyTxn(index);
}
}
Assert.assertEquals(2, StateMachineWithConditionalWait.numTxns.values().stream()
.filter(val -> val.get() == 3).count());
Assert.assertTrue(secondFollower.getInfo().getLastAppliedIndex() < minIndex);
((StateMachineWithConditionalWait) secondFollower.getStateMachine()).unBlockApplyTxn(minIndex);

// Now wait for the thread
t.join(5000);
Assert.assertEquals(logIndex, secondFollower.getInfo().getLastAppliedIndex());
Assert.assertEquals(3, StateMachineWithConditionalWait.numTxns.values().stream()
.filter(val -> val.get() == 3).count());

cluster.shutdown();
}
Expand Down
Loading