diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java index f60ff3eea..d9f690131 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java @@ -93,6 +93,11 @@ public class MysqlSession implements Session { private final String databaseName; private final String tableName; + @FunctionalInterface + private interface SqlOperation { + void execute() throws Exception; + } + /** * Create a MysqlSession with default settings. * @@ -285,6 +290,38 @@ private String getFullTableName() { return "`" + databaseName + "`.`" + tableName + "`"; } + /** + * Execute a write operation in an explicit transaction. + * + *
MysqlSession obtains and owns a fresh JDBC connection for each write method call. This
+ * helper makes write semantics consistent even when the underlying DataSource defaults to
+ * {@code autoCommit=false}, and restores the connection's original auto-commit mode before
+ * returning it to the pool.
+ */
+ private void executeInWriteTransaction(Connection conn, SqlOperation operation)
+ throws Exception {
+ boolean originalAutoCommit = conn.getAutoCommit();
+ if (originalAutoCommit) {
+ conn.setAutoCommit(false);
+ }
+
+ try {
+ operation.execute();
+ conn.commit();
+ } catch (Exception e) {
+ try {
+ conn.rollback();
+ } catch (SQLException rollbackException) {
+ e.addSuppressed(rollbackException);
+ }
+ throw e;
+ } finally {
+ if (conn.getAutoCommit() != originalAutoCommit) {
+ conn.setAutoCommit(originalAutoCommit);
+ }
+ }
+ }
+
@Override
public void save(SessionKey sessionKey, String key, State value) {
String sessionId = sessionKey.toIdentifier();
@@ -298,18 +335,21 @@ public void save(SessionKey sessionKey, String key, State value) {
+ " VALUES (?, ?, ?, ?)"
+ " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)";
- try (Connection conn = dataSource.getConnection();
- PreparedStatement stmt = conn.prepareStatement(upsertSql)) {
-
- String json = JsonUtils.getJsonCodec().toJson(value);
-
- stmt.setString(1, sessionId);
- stmt.setString(2, key);
- stmt.setInt(3, SINGLE_STATE_INDEX);
- stmt.setString(4, json);
-
- stmt.executeUpdate();
-
+ try (Connection conn = dataSource.getConnection()) {
+ executeInWriteTransaction(
+ conn,
+ () -> {
+ try (PreparedStatement stmt = conn.prepareStatement(upsertSql)) {
+ String json = JsonUtils.getJsonCodec().toJson(value);
+
+ stmt.setString(1, sessionId);
+ stmt.setString(2, key);
+ stmt.setInt(3, SINGLE_STATE_INDEX);
+ stmt.setString(4, json);
+
+ stmt.executeUpdate();
+ }
+ });
} catch (Exception e) {
throw new RuntimeException("Failed to save state: " + key, e);
}
@@ -344,42 +384,35 @@ public void save(SessionKey sessionKey, String key, List extends State> values
String hashKey = key + HASH_KEY_SUFFIX;
try (Connection conn = dataSource.getConnection()) {
- // Compute current hash
- String currentHash = ListHashUtil.computeHash(values);
-
- // Get stored hash
- String storedHash = getStoredHash(conn, sessionId, hashKey);
-
- // Get existing count
- int existingCount = getListCount(conn, sessionId, key);
-
- // Determine if full rewrite is needed
- boolean needsFullRewrite =
- ListHashUtil.needsFullRewrite(
- currentHash, storedHash, values.size(), existingCount);
-
- if (needsFullRewrite) {
- // Transaction: delete all + insert all
- conn.setAutoCommit(false);
- try {
- deleteListItems(conn, sessionId, key);
- insertAllItems(conn, sessionId, key, values);
- saveHash(conn, sessionId, hashKey, currentHash);
- conn.commit();
- } catch (Exception e) {
- conn.rollback();
- throw e;
- } finally {
- conn.setAutoCommit(true);
- }
- } else if (values.size() > existingCount) {
- // Incremental append
- List extends State> newItems = values.subList(existingCount, values.size());
- insertItems(conn, sessionId, key, newItems, existingCount);
- saveHash(conn, sessionId, hashKey, currentHash);
- }
- // else: no change, skip
-
+ executeInWriteTransaction(
+ conn,
+ () -> {
+ // Compute current hash
+ String currentHash = ListHashUtil.computeHash(values);
+
+ // Get stored hash
+ String storedHash = getStoredHash(conn, sessionId, hashKey);
+
+ // Get existing count
+ int existingCount = getListCount(conn, sessionId, key);
+
+ // Determine if full rewrite is needed
+ boolean needsFullRewrite =
+ ListHashUtil.needsFullRewrite(
+ currentHash, storedHash, values.size(), existingCount);
+
+ if (needsFullRewrite) {
+ deleteListItems(conn, sessionId, key);
+ insertAllItems(conn, sessionId, key, values);
+ saveHash(conn, sessionId, hashKey, currentHash);
+ } else if (values.size() > existingCount) {
+ List extends State> newItems =
+ values.subList(existingCount, values.size());
+ insertItems(conn, sessionId, key, newItems, existingCount);
+ saveHash(conn, sessionId, hashKey, currentHash);
+ }
+ // else: no change, skip
+ });
} catch (Exception e) {
throw new RuntimeException("Failed to save list: " + key, e);
}
@@ -626,13 +659,16 @@ public void delete(SessionKey sessionKey) {
String deleteSql = "DELETE FROM " + getFullTableName() + " WHERE session_id = ?";
- try (Connection conn = dataSource.getConnection();
- PreparedStatement stmt = conn.prepareStatement(deleteSql)) {
-
- stmt.setString(1, sessionId);
- stmt.executeUpdate();
-
- } catch (SQLException e) {
+ try (Connection conn = dataSource.getConnection()) {
+ executeInWriteTransaction(
+ conn,
+ () -> {
+ try (PreparedStatement stmt = conn.prepareStatement(deleteSql)) {
+ stmt.setString(1, sessionId);
+ stmt.executeUpdate();
+ }
+ });
+ } catch (Exception e) {
throw new RuntimeException("Failed to delete session: " + sessionId, e);
}
}
@@ -705,12 +741,17 @@ public DataSource getDataSource() {
public int clearAllSessions() {
String clearSql = "DELETE FROM " + getFullTableName();
- try (Connection conn = dataSource.getConnection();
- PreparedStatement stmt = conn.prepareStatement(clearSql)) {
-
- return stmt.executeUpdate();
-
- } catch (SQLException e) {
+ try (Connection conn = dataSource.getConnection()) {
+ int[] deletedRows = new int[1];
+ executeInWriteTransaction(
+ conn,
+ () -> {
+ try (PreparedStatement stmt = conn.prepareStatement(clearSql)) {
+ deletedRows[0] = stmt.executeUpdate();
+ }
+ });
+ return deletedRows[0];
+ } catch (Exception e) {
throw new RuntimeException("Failed to clear sessions", e);
}
}
@@ -730,12 +771,17 @@ public int clearAllSessions() {
public int truncateAllSessions() {
String clearSql = "TRUNCATE TABLE " + getFullTableName();
- try (Connection conn = dataSource.getConnection();
- PreparedStatement stmt = conn.prepareStatement(clearSql)) {
-
- return stmt.executeUpdate();
-
- } catch (SQLException e) {
+ try (Connection conn = dataSource.getConnection()) {
+ int[] truncateResult = new int[1];
+ executeInWriteTransaction(
+ conn,
+ () -> {
+ try (PreparedStatement stmt = conn.prepareStatement(clearSql)) {
+ truncateResult[0] = stmt.executeUpdate();
+ }
+ });
+ return truncateResult[0];
+ } catch (Exception e) {
throw new RuntimeException("Failed to truncate sessions", e);
}
}
diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java
index b720b7376..f46e9ffeb 100644
--- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java
+++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java
@@ -21,6 +21,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -232,6 +233,22 @@ void testSaveAndGetSingleState() throws SQLException {
assertEquals(42, loaded.get().count());
}
+ @Test
+ @DisplayName("Should commit single state save when connection auto-commit is disabled")
+ void testSaveSingleStateCommitsWhenAutoCommitDisabled() throws SQLException {
+ when(mockConnection.getAutoCommit()).thenReturn(false);
+ when(mockStatement.execute()).thenReturn(true);
+ when(mockStatement.executeUpdate()).thenReturn(1);
+
+ MysqlSession session = new MysqlSession(mockDataSource, true);
+ SessionKey sessionKey = SimpleSessionKey.of("session_auto_commit_off");
+
+ session.save(sessionKey, "testModule", new TestState("test_value", 42));
+
+ verify(mockConnection).commit();
+ verify(mockConnection, never()).setAutoCommit(true);
+ }
+
@Test
@DisplayName("Should save and get list state correctly")
void testSaveAndGetListState() throws SQLException {
@@ -265,6 +282,49 @@ void testSaveAndGetListState() throws SQLException {
assertEquals("value2", loaded.get(1).value());
}
+ @Test
+ @DisplayName("Should commit incremental list save when connection auto-commit is disabled")
+ void testSaveListIncrementalAppendCommitsWhenAutoCommitDisabled() throws SQLException {
+ when(mockConnection.getAutoCommit()).thenReturn(false);
+ when(mockStatement.execute()).thenReturn(true);
+ when(mockStatement.executeQuery()).thenReturn(mockResultSet);
+ when(mockResultSet.next()).thenReturn(false, true);
+ when(mockResultSet.getInt("max_index")).thenReturn(0);
+ when(mockResultSet.wasNull()).thenReturn(true);
+
+ MysqlSession session = new MysqlSession(mockDataSource, true);
+ SessionKey sessionKey = SimpleSessionKey.of("session_list_auto_commit_off");
+ List