Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ public class MysqlSession implements Session {
private final String databaseName;
private final String tableName;

@FunctionalInterface
private interface SqlOperation {
void execute() throws Exception;
}

/**
* Create a MysqlSession with default settings.
*
Expand Down Expand Up @@ -285,6 +290,38 @@ private String getFullTableName() {
return "`" + databaseName + "`.`" + tableName + "`";
}

/**
* Execute a write operation in an explicit transaction.
*
* <p>MysqlSession obtains and owns a fresh JDBC connection for each write method call. This
* helper makes write semantics consistent even when the underlying DataSource defaults to
* {@code autoCommit=false}, and restores the connection's original auto-commit mode before
* returning it to the pool.
*/
private void executeInWriteTransaction(Connection conn, SqlOperation operation)
throws Exception {
boolean originalAutoCommit = conn.getAutoCommit();
if (originalAutoCommit) {
conn.setAutoCommit(false);
}

try {
operation.execute();
conn.commit();
} catch (Exception e) {
try {
conn.rollback();
} catch (SQLException rollbackException) {
e.addSuppressed(rollbackException);
}
throw e;
} finally {
if (conn.getAutoCommit() != originalAutoCommit) {
conn.setAutoCommit(originalAutoCommit);
}
}
}

@Override
public void save(SessionKey sessionKey, String key, State value) {
String sessionId = sessionKey.toIdentifier();
Expand All @@ -298,18 +335,21 @@ public void save(SessionKey sessionKey, String key, State value) {
+ " VALUES (?, ?, ?, ?)"
+ " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)";

try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(upsertSql)) {

String json = JsonUtils.getJsonCodec().toJson(value);

stmt.setString(1, sessionId);
stmt.setString(2, key);
stmt.setInt(3, SINGLE_STATE_INDEX);
stmt.setString(4, json);

stmt.executeUpdate();

try (Connection conn = dataSource.getConnection()) {
executeInWriteTransaction(
conn,
() -> {
try (PreparedStatement stmt = conn.prepareStatement(upsertSql)) {
String json = JsonUtils.getJsonCodec().toJson(value);

stmt.setString(1, sessionId);
stmt.setString(2, key);
stmt.setInt(3, SINGLE_STATE_INDEX);
stmt.setString(4, json);

stmt.executeUpdate();
}
});
} catch (Exception e) {
throw new RuntimeException("Failed to save state: " + key, e);
}
Expand Down Expand Up @@ -344,42 +384,35 @@ public void save(SessionKey sessionKey, String key, List<? extends State> values
String hashKey = key + HASH_KEY_SUFFIX;

try (Connection conn = dataSource.getConnection()) {
// Compute current hash
String currentHash = ListHashUtil.computeHash(values);

// Get stored hash
String storedHash = getStoredHash(conn, sessionId, hashKey);

// Get existing count
int existingCount = getListCount(conn, sessionId, key);

// Determine if full rewrite is needed
boolean needsFullRewrite =
ListHashUtil.needsFullRewrite(
currentHash, storedHash, values.size(), existingCount);

if (needsFullRewrite) {
// Transaction: delete all + insert all
conn.setAutoCommit(false);
try {
deleteListItems(conn, sessionId, key);
insertAllItems(conn, sessionId, key, values);
saveHash(conn, sessionId, hashKey, currentHash);
conn.commit();
} catch (Exception e) {
conn.rollback();
throw e;
} finally {
conn.setAutoCommit(true);
}
} else if (values.size() > existingCount) {
// Incremental append
List<? extends State> newItems = values.subList(existingCount, values.size());
insertItems(conn, sessionId, key, newItems, existingCount);
saveHash(conn, sessionId, hashKey, currentHash);
}
// else: no change, skip

executeInWriteTransaction(
conn,
() -> {
// Compute current hash
String currentHash = ListHashUtil.computeHash(values);

// Get stored hash
String storedHash = getStoredHash(conn, sessionId, hashKey);

// Get existing count
int existingCount = getListCount(conn, sessionId, key);

// Determine if full rewrite is needed
boolean needsFullRewrite =
ListHashUtil.needsFullRewrite(
currentHash, storedHash, values.size(), existingCount);

if (needsFullRewrite) {
deleteListItems(conn, sessionId, key);
insertAllItems(conn, sessionId, key, values);
saveHash(conn, sessionId, hashKey, currentHash);
} else if (values.size() > existingCount) {
List<? extends State> newItems =
values.subList(existingCount, values.size());
insertItems(conn, sessionId, key, newItems, existingCount);
saveHash(conn, sessionId, hashKey, currentHash);
}
// else: no change, skip
});
} catch (Exception e) {
throw new RuntimeException("Failed to save list: " + key, e);
}
Expand Down Expand Up @@ -626,13 +659,16 @@ public void delete(SessionKey sessionKey) {

String deleteSql = "DELETE FROM " + getFullTableName() + " WHERE session_id = ?";

try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(deleteSql)) {

stmt.setString(1, sessionId);
stmt.executeUpdate();

} catch (SQLException e) {
try (Connection conn = dataSource.getConnection()) {
executeInWriteTransaction(
conn,
() -> {
try (PreparedStatement stmt = conn.prepareStatement(deleteSql)) {
stmt.setString(1, sessionId);
stmt.executeUpdate();
}
});
} catch (Exception e) {
throw new RuntimeException("Failed to delete session: " + sessionId, e);
}
}
Expand Down Expand Up @@ -705,12 +741,17 @@ public DataSource getDataSource() {
public int clearAllSessions() {
String clearSql = "DELETE FROM " + getFullTableName();

try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(clearSql)) {

return stmt.executeUpdate();

} catch (SQLException e) {
try (Connection conn = dataSource.getConnection()) {
int[] deletedRows = new int[1];
executeInWriteTransaction(
conn,
() -> {
try (PreparedStatement stmt = conn.prepareStatement(clearSql)) {
deletedRows[0] = stmt.executeUpdate();
}
});
return deletedRows[0];
} catch (Exception e) {
throw new RuntimeException("Failed to clear sessions", e);
}
}
Expand All @@ -730,12 +771,17 @@ public int clearAllSessions() {
public int truncateAllSessions() {
String clearSql = "TRUNCATE TABLE " + getFullTableName();

try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(clearSql)) {

return stmt.executeUpdate();

} catch (SQLException e) {
try (Connection conn = dataSource.getConnection()) {
int[] truncateResult = new int[1];
executeInWriteTransaction(
conn,
() -> {
try (PreparedStatement stmt = conn.prepareStatement(clearSql)) {
truncateResult[0] = stmt.executeUpdate();
}
});
return truncateResult[0];
} catch (Exception e) {
throw new RuntimeException("Failed to truncate sessions", e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -232,6 +233,22 @@ void testSaveAndGetSingleState() throws SQLException {
assertEquals(42, loaded.get().count());
}

@Test
@DisplayName("Should commit single state save when connection auto-commit is disabled")
void testSaveSingleStateCommitsWhenAutoCommitDisabled() throws SQLException {
when(mockConnection.getAutoCommit()).thenReturn(false);
when(mockStatement.execute()).thenReturn(true);
when(mockStatement.executeUpdate()).thenReturn(1);

MysqlSession session = new MysqlSession(mockDataSource, true);
SessionKey sessionKey = SimpleSessionKey.of("session_auto_commit_off");

session.save(sessionKey, "testModule", new TestState("test_value", 42));

verify(mockConnection).commit();
verify(mockConnection, never()).setAutoCommit(true);
}

@Test
@DisplayName("Should save and get list state correctly")
void testSaveAndGetListState() throws SQLException {
Expand Down Expand Up @@ -265,6 +282,49 @@ void testSaveAndGetListState() throws SQLException {
assertEquals("value2", loaded.get(1).value());
}

@Test
@DisplayName("Should commit incremental list save when connection auto-commit is disabled")
void testSaveListIncrementalAppendCommitsWhenAutoCommitDisabled() throws SQLException {
when(mockConnection.getAutoCommit()).thenReturn(false);
when(mockStatement.execute()).thenReturn(true);
when(mockStatement.executeQuery()).thenReturn(mockResultSet);
when(mockResultSet.next()).thenReturn(false, true);
when(mockResultSet.getInt("max_index")).thenReturn(0);
when(mockResultSet.wasNull()).thenReturn(true);

MysqlSession session = new MysqlSession(mockDataSource, true);
SessionKey sessionKey = SimpleSessionKey.of("session_list_auto_commit_off");
List<TestState> states = List.of(new TestState("value1", 1), new TestState("value2", 2));

session.save(sessionKey, "testList", states);

verify(mockConnection).commit();
verify(mockConnection, never()).setAutoCommit(true);
}

@Test
@DisplayName(
"Should not force auto-commit true after full rewrite when connection starts disabled")
void testSaveListFullRewriteRestoresOriginalAutoCommitState() throws SQLException {
when(mockConnection.getAutoCommit()).thenReturn(false);
when(mockStatement.execute()).thenReturn(true);
when(mockStatement.executeQuery()).thenReturn(mockResultSet);
when(mockResultSet.next()).thenReturn(true, true);
when(mockResultSet.getString("state_data")).thenReturn("stale_hash");
when(mockResultSet.getInt("max_index")).thenReturn(0);
when(mockResultSet.wasNull()).thenReturn(false);
when(mockStatement.executeUpdate()).thenReturn(1);

MysqlSession session = new MysqlSession(mockDataSource, true);
SessionKey sessionKey = SimpleSessionKey.of("session_full_rewrite_auto_commit_off");
List<TestState> states = List.of(new TestState("value1", 1));

session.save(sessionKey, "testList", states);

verify(mockConnection).commit();
verify(mockConnection, never()).setAutoCommit(true);
}

@Test
@DisplayName("Should return empty for non-existent state")
void testGetNonExistentState() throws SQLException {
Expand Down Expand Up @@ -334,6 +394,22 @@ void testDeleteSession() throws SQLException {
verify(mockStatement).executeUpdate();
}

@Test
@DisplayName("Should commit delete when connection auto-commit is disabled")
void testDeleteSessionCommitsWhenAutoCommitDisabled() throws SQLException {
when(mockConnection.getAutoCommit()).thenReturn(false);
when(mockStatement.execute()).thenReturn(true);
when(mockStatement.executeUpdate()).thenReturn(1);

MysqlSession session = new MysqlSession(mockDataSource, true);
SessionKey sessionKey = SimpleSessionKey.of("session1");

session.delete(sessionKey);

verify(mockConnection).commit();
verify(mockConnection, never()).setAutoCommit(true);
}

@Test
@DisplayName("Should list all session keys when empty")
void testListSessionKeysEmpty() throws SQLException {
Expand Down Expand Up @@ -375,6 +451,21 @@ void testClearAllSessions() throws SQLException {
assertEquals(5, deleted);
}

@Test
@DisplayName("Should commit clearAllSessions when connection auto-commit is disabled")
void testClearAllSessionsCommitsWhenAutoCommitDisabled() throws SQLException {
when(mockConnection.getAutoCommit()).thenReturn(false);
when(mockStatement.execute()).thenReturn(true);
when(mockStatement.executeUpdate()).thenReturn(5);

MysqlSession session = new MysqlSession(mockDataSource, true);
int deleted = session.clearAllSessions();

assertEquals(5, deleted);
verify(mockConnection).commit();
verify(mockConnection, never()).setAutoCommit(true);
}

@Test
@DisplayName("Should truncate session table")
void testTruncateAllSessions() throws SQLException {
Expand All @@ -387,6 +478,21 @@ void testTruncateAllSessions() throws SQLException {
assertEquals(0, success);
}

@Test
@DisplayName("Should commit truncateAllSessions when connection auto-commit is disabled")
void testTruncateAllSessionsCommitsWhenAutoCommitDisabled() throws SQLException {
when(mockConnection.getAutoCommit()).thenReturn(false);
when(mockStatement.execute()).thenReturn(true);
when(mockStatement.executeUpdate()).thenReturn(0);

MysqlSession session = new MysqlSession(mockDataSource, true);
int success = session.truncateAllSessions();

assertEquals(0, success);
verify(mockConnection).commit();
verify(mockConnection, never()).setAutoCommit(true);
}

@Test
@DisplayName("Should not close DataSource when closing session")
void testClose() throws SQLException {
Expand Down
Loading
Loading