From 94174dda4c37f2ab93aaef04f7ad91c4a5aaa417 Mon Sep 17 00:00:00 2001 From: nicholas1485 <3309028585@qq.com> Date: Tue, 31 Mar 2026 23:00:21 +0800 Subject: [PATCH] fix(session-mysql): commit writes when auto-commit is disabled --- .../core/session/mysql/MysqlSession.java | 180 +++++++++++------- .../core/session/mysql/MysqlSessionTest.java | 106 +++++++++++ .../mysql/e2e/MysqlSessionE2ETest.java | 133 +++++++++++++ 3 files changed, 352 insertions(+), 67 deletions(-) diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java index f60ff3eea..d9f690131 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java @@ -93,6 +93,11 @@ public class MysqlSession implements Session { private final String databaseName; private final String tableName; + @FunctionalInterface + private interface SqlOperation { + void execute() throws Exception; + } + /** * Create a MysqlSession with default settings. * @@ -285,6 +290,38 @@ private String getFullTableName() { return "`" + databaseName + "`.`" + tableName + "`"; } + /** + * Execute a write operation in an explicit transaction. + * + *

MysqlSession obtains and owns a fresh JDBC connection for each write method call. This + * helper makes write semantics consistent even when the underlying DataSource defaults to + * {@code autoCommit=false}, and restores the connection's original auto-commit mode before + * returning it to the pool. + */ + private void executeInWriteTransaction(Connection conn, SqlOperation operation) + throws Exception { + boolean originalAutoCommit = conn.getAutoCommit(); + if (originalAutoCommit) { + conn.setAutoCommit(false); + } + + try { + operation.execute(); + conn.commit(); + } catch (Exception e) { + try { + conn.rollback(); + } catch (SQLException rollbackException) { + e.addSuppressed(rollbackException); + } + throw e; + } finally { + if (conn.getAutoCommit() != originalAutoCommit) { + conn.setAutoCommit(originalAutoCommit); + } + } + } + @Override public void save(SessionKey sessionKey, String key, State value) { String sessionId = sessionKey.toIdentifier(); @@ -298,18 +335,21 @@ public void save(SessionKey sessionKey, String key, State value) { + " VALUES (?, ?, ?, ?)" + " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)"; - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(upsertSql)) { - - String json = JsonUtils.getJsonCodec().toJson(value); - - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.setInt(3, SINGLE_STATE_INDEX); - stmt.setString(4, json); - - stmt.executeUpdate(); - + try (Connection conn = dataSource.getConnection()) { + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(upsertSql)) { + String json = JsonUtils.getJsonCodec().toJson(value); + + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.setInt(3, SINGLE_STATE_INDEX); + stmt.setString(4, json); + + stmt.executeUpdate(); + } + }); } catch (Exception e) { throw new RuntimeException("Failed to save state: " + key, e); } @@ -344,42 +384,35 @@ public void save(SessionKey sessionKey, String key, List values String hashKey = key + HASH_KEY_SUFFIX; try (Connection conn = dataSource.getConnection()) { - // Compute current hash - String currentHash = ListHashUtil.computeHash(values); - - // Get stored hash - String storedHash = getStoredHash(conn, sessionId, hashKey); - - // Get existing count - int existingCount = getListCount(conn, sessionId, key); - - // Determine if full rewrite is needed - boolean needsFullRewrite = - ListHashUtil.needsFullRewrite( - currentHash, storedHash, values.size(), existingCount); - - if (needsFullRewrite) { - // Transaction: delete all + insert all - conn.setAutoCommit(false); - try { - deleteListItems(conn, sessionId, key); - insertAllItems(conn, sessionId, key, values); - saveHash(conn, sessionId, hashKey, currentHash); - conn.commit(); - } catch (Exception e) { - conn.rollback(); - throw e; - } finally { - conn.setAutoCommit(true); - } - } else if (values.size() > existingCount) { - // Incremental append - List newItems = values.subList(existingCount, values.size()); - insertItems(conn, sessionId, key, newItems, existingCount); - saveHash(conn, sessionId, hashKey, currentHash); - } - // else: no change, skip - + executeInWriteTransaction( + conn, + () -> { + // Compute current hash + String currentHash = ListHashUtil.computeHash(values); + + // Get stored hash + String storedHash = getStoredHash(conn, sessionId, hashKey); + + // Get existing count + int existingCount = getListCount(conn, sessionId, key); + + // Determine if full rewrite is needed + boolean needsFullRewrite = + ListHashUtil.needsFullRewrite( + currentHash, storedHash, values.size(), existingCount); + + if (needsFullRewrite) { + deleteListItems(conn, sessionId, key); + insertAllItems(conn, sessionId, key, values); + saveHash(conn, sessionId, hashKey, currentHash); + } else if (values.size() > existingCount) { + List newItems = + values.subList(existingCount, values.size()); + insertItems(conn, sessionId, key, newItems, existingCount); + saveHash(conn, sessionId, hashKey, currentHash); + } + // else: no change, skip + }); } catch (Exception e) { throw new RuntimeException("Failed to save list: " + key, e); } @@ -626,13 +659,16 @@ public void delete(SessionKey sessionKey) { String deleteSql = "DELETE FROM " + getFullTableName() + " WHERE session_id = ?"; - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(deleteSql)) { - - stmt.setString(1, sessionId); - stmt.executeUpdate(); - - } catch (SQLException e) { + try (Connection conn = dataSource.getConnection()) { + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(deleteSql)) { + stmt.setString(1, sessionId); + stmt.executeUpdate(); + } + }); + } catch (Exception e) { throw new RuntimeException("Failed to delete session: " + sessionId, e); } } @@ -705,12 +741,17 @@ public DataSource getDataSource() { public int clearAllSessions() { String clearSql = "DELETE FROM " + getFullTableName(); - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(clearSql)) { - - return stmt.executeUpdate(); - - } catch (SQLException e) { + try (Connection conn = dataSource.getConnection()) { + int[] deletedRows = new int[1]; + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(clearSql)) { + deletedRows[0] = stmt.executeUpdate(); + } + }); + return deletedRows[0]; + } catch (Exception e) { throw new RuntimeException("Failed to clear sessions", e); } } @@ -730,12 +771,17 @@ public int clearAllSessions() { public int truncateAllSessions() { String clearSql = "TRUNCATE TABLE " + getFullTableName(); - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(clearSql)) { - - return stmt.executeUpdate(); - - } catch (SQLException e) { + try (Connection conn = dataSource.getConnection()) { + int[] truncateResult = new int[1]; + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(clearSql)) { + truncateResult[0] = stmt.executeUpdate(); + } + }); + return truncateResult[0]; + } catch (Exception e) { throw new RuntimeException("Failed to truncate sessions", e); } } diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java index b720b7376..f46e9ffeb 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java @@ -21,6 +21,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -232,6 +233,22 @@ void testSaveAndGetSingleState() throws SQLException { assertEquals(42, loaded.get().count()); } + @Test + @DisplayName("Should commit single state save when connection auto-commit is disabled") + void testSaveSingleStateCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session_auto_commit_off"); + + session.save(sessionKey, "testModule", new TestState("test_value", 42)); + + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(true); + } + @Test @DisplayName("Should save and get list state correctly") void testSaveAndGetListState() throws SQLException { @@ -265,6 +282,49 @@ void testSaveAndGetListState() throws SQLException { assertEquals("value2", loaded.get(1).value()); } + @Test + @DisplayName("Should commit incremental list save when connection auto-commit is disabled") + void testSaveListIncrementalAppendCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false, true); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session_list_auto_commit_off"); + List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); + + session.save(sessionKey, "testList", states); + + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(true); + } + + @Test + @DisplayName( + "Should not force auto-commit true after full rewrite when connection starts disabled") + void testSaveListFullRewriteRestoresOriginalAutoCommitState() throws SQLException { + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true); + when(mockResultSet.getString("state_data")).thenReturn("stale_hash"); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(false); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session_full_rewrite_auto_commit_off"); + List states = List.of(new TestState("value1", 1)); + + session.save(sessionKey, "testList", states); + + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(true); + } + @Test @DisplayName("Should return empty for non-existent state") void testGetNonExistentState() throws SQLException { @@ -334,6 +394,22 @@ void testDeleteSession() throws SQLException { verify(mockStatement).executeUpdate(); } + @Test + @DisplayName("Should commit delete when connection auto-commit is disabled") + void testDeleteSessionCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + + session.delete(sessionKey); + + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(true); + } + @Test @DisplayName("Should list all session keys when empty") void testListSessionKeysEmpty() throws SQLException { @@ -375,6 +451,21 @@ void testClearAllSessions() throws SQLException { assertEquals(5, deleted); } + @Test + @DisplayName("Should commit clearAllSessions when connection auto-commit is disabled") + void testClearAllSessionsCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(5); + + MysqlSession session = new MysqlSession(mockDataSource, true); + int deleted = session.clearAllSessions(); + + assertEquals(5, deleted); + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(true); + } + @Test @DisplayName("Should truncate session table") void testTruncateAllSessions() throws SQLException { @@ -387,6 +478,21 @@ void testTruncateAllSessions() throws SQLException { assertEquals(0, success); } + @Test + @DisplayName("Should commit truncateAllSessions when connection auto-commit is disabled") + void testTruncateAllSessionsCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(0); + + MysqlSession session = new MysqlSession(mockDataSource, true); + int success = session.truncateAllSessions(); + + assertEquals(0, success); + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(true); + } + @Test @DisplayName("Should not close DataSource when closing session") void testClose() throws SQLException { diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java index 6cbab1ab9..74f74e211 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java @@ -24,13 +24,16 @@ import io.agentscope.core.state.SessionKey; import io.agentscope.core.state.SimpleSessionKey; import io.agentscope.core.state.State; +import java.io.PrintWriter; import java.sql.Connection; import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; import java.sql.Statement; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.logging.Logger; import javax.sql.DataSource; import org.h2.jdbcx.JdbcDataSource; import org.junit.jupiter.api.AfterEach; @@ -159,6 +162,83 @@ void testSaveAndLoadListState() { assertEquals("item3", allLoaded.get(2).value()); } + @Test + @DisplayName("Save persists when DataSource connections default to auto-commit false") + void testSavePersistsWithAutoCommitDisabledConnections() { + System.out.println("\n=== Test: Save With Auto-Commit Disabled Connections ==="); + + DataSource baseDataSource = createH2DataSource(); + dataSource = baseDataSource; + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(baseDataSource, schemaName, tableName); + + MysqlSession session = + new MysqlSession( + wrapWithAutoCommit(baseDataSource, false), schemaName, tableName, false); + + SessionKey sessionKey = + SimpleSessionKey.of("mysql_e2e_autocommit_off_" + UUID.randomUUID()); + + session.save(sessionKey, "moduleA", new TestState("hello", 1)); + session.save( + sessionKey, + "stateList", + List.of(new TestState("item1", 1), new TestState("item2", 2))); + + assertTrue(session.exists(sessionKey)); + + Optional loadedState = session.get(sessionKey, "moduleA", TestState.class); + assertTrue(loadedState.isPresent()); + assertEquals("hello", loadedState.get().value()); + + List loadedList = session.getList(sessionKey, "stateList", TestState.class); + assertEquals(2, loadedList.size()); + assertEquals("item1", loadedList.get(0).value()); + assertEquals("item2", loadedList.get(1).value()); + } + + @Test + @DisplayName( + "Delete and cleanup persist when DataSource connections default to auto-commit false") + void testDeleteAndCleanupPersistWithAutoCommitDisabledConnections() { + System.out.println( + "\n=== Test: Delete And Cleanup With Auto-Commit Disabled Connections ==="); + + DataSource baseDataSource = createH2DataSource(); + dataSource = baseDataSource; + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(baseDataSource, schemaName, tableName); + + MysqlSession session = + new MysqlSession( + wrapWithAutoCommit(baseDataSource, false), schemaName, tableName, false); + + SessionKey sessionKey1 = SimpleSessionKey.of("mysql_e2e_delete_" + UUID.randomUUID()); + SessionKey sessionKey2 = SimpleSessionKey.of("mysql_e2e_clear_" + UUID.randomUUID()); + + session.save(sessionKey1, "moduleA", new TestState("hello", 1)); + session.save(sessionKey2, "moduleA", new TestState("world", 2)); + + session.delete(sessionKey1); + assertFalse(session.exists(sessionKey1)); + assertTrue(session.exists(sessionKey2)); + + session.clearAllSessions(); + assertTrue(session.listSessionKeys().isEmpty()); + + session.save(sessionKey1, "moduleA", new TestState("hello_again", 3)); + assertTrue(session.exists(sessionKey1)); + + session.truncateAllSessions(); + assertTrue(session.listSessionKeys().isEmpty()); + } + @Test @DisplayName("Session does not exist should return false") void testSessionNotExists() { @@ -217,6 +297,59 @@ private static DataSource createH2DataSource() { return ds; } + private static DataSource wrapWithAutoCommit(DataSource delegate, boolean autoCommit) { + return new DataSource() { + @Override + public Connection getConnection() throws SQLException { + Connection conn = delegate.getConnection(); + conn.setAutoCommit(autoCommit); + return conn; + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + Connection conn = delegate.getConnection(username, password); + conn.setAutoCommit(autoCommit); + return conn; + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return delegate.getLogWriter(); + } + + @Override + public void setLogWriter(PrintWriter out) throws SQLException { + delegate.setLogWriter(out); + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + delegate.setLoginTimeout(seconds); + } + + @Override + public int getLoginTimeout() throws SQLException { + return delegate.getLoginTimeout(); + } + + @Override + public Logger getParentLogger() throws SQLFeatureNotSupportedException { + return delegate.getParentLogger(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return delegate.isWrapperFor(iface); + } + }; + } + /** Generates a safe MySQL identifier (letters/numbers/underscore) and keeps it <= 64 chars. */ private static String generateSafeIdentifier(String prefix) { String suffix = UUID.randomUUID().toString().replace("-", "_");