From f333f6bd70a8119bfd5bc697744b3c79836234ff Mon Sep 17 00:00:00 2001 From: Jean Boussier Date: Tue, 25 Jul 2023 16:55:23 +0200 Subject: [PATCH] Add a reference to the connection from prepared statements Ref: https://github.com/trilogy-libraries/trilogy/issues/105 To close a prepared statement you need to have access to the connection that created it. But in managed languages like Ruby, the obvious thing to do is to close the prepared statement when the associated object is garbage collected. But the order in which objects are garbaged collected is never guaranteed so when freeing a statement the connection might have been freed already and it's hard to detect. We checked how libmysqlclient handles it, and each `MYSQL_STMT` has a reference to its `MYSQL` (connection), and the connection keeps a doubly linked list of the statements it created. When a statement is closed it's removed from the list, when the connection is closed, all the connection references are set to NULL. We implemented exactly the same logic here. Additionally, prepared statement can only be used with the connection they were created from. As such having all the `trilogy_stmt_*` function take a connection isn't great for usability. So this change opens the door to only taking a `trilogy_stmt_t *`. Co-Authored-By: Adrianna Chang --- inc/trilogy/client.h | 14 +++++++- src/blocking.c | 30 ++++++++++++++++ src/client.c | 16 ++++++++- test/blocking_test.c | 85 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 143 insertions(+), 2 deletions(-) diff --git a/inc/trilogy/client.h b/inc/trilogy/client.h index afce6273..59f5f1b2 100644 --- a/inc/trilogy/client.h +++ b/inc/trilogy/client.h @@ -63,6 +63,8 @@ */ typedef trilogy_column_packet_t trilogy_column_t; +typedef struct trilogy_stmt trilogy_stmt_t; + /* trilogy_conn_t - The Trilogy client's instance type. * * This type is shared for the non-blocking and blocking versions of the API. @@ -82,6 +84,7 @@ typedef struct { uint16_t server_status; trilogy_sock_t *socket; + trilogy_stmt_t *prepared_statements; // private: uint8_t recv_buff[TRILOGY_DEFAULT_BUF_SIZE]; @@ -619,7 +622,16 @@ int trilogy_stmt_prepare_send(trilogy_conn_t *conn, const char *stmt, size_t stm /* trilogy_stmt_t - The trilogy client's prepared statement type. */ -typedef trilogy_stmt_ok_packet_t trilogy_stmt_t; + +struct trilogy_stmt { + trilogy_stmt_t *prev; + trilogy_stmt_t *next; + uint32_t id; + uint16_t column_count; + uint16_t parameter_count; + uint16_t warning_count; + trilogy_conn_t *connection; +}; /* trilogy_stmt_prepare_recv - Read the prepared statement prepare command response * from the MySQL-compatible server. diff --git a/src/blocking.c b/src/blocking.c index 3afa562e..119e9aa0 100644 --- a/src/blocking.c +++ b/src/blocking.c @@ -263,6 +263,8 @@ int trilogy_close(trilogy_conn_t *conn) int trilogy_stmt_prepare(trilogy_conn_t *conn, const char *stmt, size_t stmt_len, trilogy_stmt_t *stmt_out) { + memset(stmt_out, 0, sizeof(trilogy_stmt_t)); + int rc = trilogy_stmt_prepare_send(conn, stmt, stmt_len); if (rc == TRILOGY_AGAIN) { @@ -276,6 +278,16 @@ int trilogy_stmt_prepare(trilogy_conn_t *conn, const char *stmt, size_t stmt_len while (1) { rc = trilogy_stmt_prepare_recv(conn, stmt_out); + if (rc == TRILOGY_OK) { + stmt_out->connection = conn; + if (conn->prepared_statements) { + stmt_out->next = conn->prepared_statements; + stmt_out->next->prev = stmt_out; + } + conn->prepared_statements = stmt_out; + return rc; + } + if (rc != TRILOGY_AGAIN) { return rc; } @@ -365,6 +377,10 @@ int trilogy_stmt_reset(trilogy_conn_t *conn, trilogy_stmt_t *stmt) int trilogy_stmt_close(trilogy_conn_t *conn, trilogy_stmt_t *stmt) { + if (!stmt->connection || conn != stmt->connection) { + // User BUG!!! Return an error or crash? + } + int rc = trilogy_stmt_close_send(conn, stmt); if (rc == TRILOGY_AGAIN) { @@ -375,6 +391,20 @@ int trilogy_stmt_close(trilogy_conn_t *conn, trilogy_stmt_t *stmt) return rc; } + if (stmt->prev == NULL) { + // assert stmt->connection->prepared_statements == stmt + stmt->connection->prepared_statements = stmt->next; + if (stmt->next) { + stmt->next->prev = NULL; + } + } else { + stmt->prev->next = stmt->next; + if (stmt->next) { + stmt->next->prev = stmt->prev; + } + } + + stmt->connection = NULL; return TRILOGY_OK; } diff --git a/src/client.c b/src/client.c index e78abd6e..3d3907b8 100644 --- a/src/client.c +++ b/src/client.c @@ -145,6 +145,8 @@ int trilogy_init(trilogy_conn_t *conn) conn->recv_buff_pos = 0; conn->recv_buff_len = 0; + conn->prepared_statements = NULL; + trilogy_packet_parser_init(&conn->packet_parser, &packet_parser_callbacks); conn->packet_parser.user_data = &conn->packet_buffer; @@ -765,6 +767,12 @@ void trilogy_free(trilogy_conn_t *conn) conn->socket = NULL; } + trilogy_stmt_t *stmt = conn->prepared_statements; + while (stmt) { + stmt->connection = NULL; + stmt = stmt->next; + } + trilogy_buffer_free(&conn->packet_buffer); } @@ -803,13 +811,19 @@ int trilogy_stmt_prepare_recv(trilogy_conn_t *conn, trilogy_stmt_t *stmt_out) switch (current_packet_type(conn)) { case TRILOGY_PACKET_OK: { - err = trilogy_parse_stmt_ok_packet(conn->packet_buffer.buff, conn->packet_buffer.len, stmt_out); + trilogy_stmt_ok_packet_t out_packet; + err = trilogy_parse_stmt_ok_packet(conn->packet_buffer.buff, conn->packet_buffer.len, &out_packet); if (err < 0) { return err; } conn->warning_count = stmt_out->warning_count; + stmt_out->connection = conn; + stmt_out->id = out_packet.id; + stmt_out->column_count = out_packet.column_count; + stmt_out->parameter_count = out_packet.parameter_count; + stmt_out->warning_count = out_packet.warning_count; return TRILOGY_OK; } diff --git a/test/blocking_test.c b/test/blocking_test.c index cd37a1c3..eddbc575 100644 --- a/test/blocking_test.c +++ b/test/blocking_test.c @@ -176,6 +176,7 @@ TEST test_blocking_stmt_prepare() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -204,6 +205,7 @@ TEST test_blocking_stmt_execute_str() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -258,6 +260,7 @@ TEST test_blocking_stmt_execute_integer() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -332,6 +335,7 @@ TEST test_blocking_stmt_execute_double() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -384,6 +388,7 @@ TEST test_blocking_stmt_execute_float() { int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -443,6 +448,7 @@ TEST test_blocking_stmt_execute_long() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -516,6 +522,7 @@ TEST test_blocking_stmt_execute_short() { int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -589,6 +596,7 @@ TEST test_blocking_stmt_execute_tiny() { int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -663,6 +671,7 @@ TEST test_blocking_stmt_execute_datetime() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(0, stmt.parameter_count); @@ -714,6 +723,7 @@ TEST test_blocking_stmt_execute_time() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(0, stmt.parameter_count); @@ -762,6 +772,7 @@ TEST test_blocking_stmt_execute_year() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(0, stmt.parameter_count); @@ -808,6 +819,7 @@ TEST test_blocking_stmt_reset() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -839,6 +851,7 @@ TEST test_blocking_stmt_close() int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); ASSERT_OK(err); + ASSERT(stmt.connection); ASSERT_EQ(1, stmt.parameter_count); @@ -851,11 +864,82 @@ TEST test_blocking_stmt_close() trilogy_column_packet_t column_def; err = trilogy_read_full_column(&conn, &column_def); ASSERT_OK(err); + ASSERT_EQ(conn.prepared_statements, &stmt); + + const char *query2 = "SELECT YEAR('2022-01-31')"; + trilogy_stmt_t stmt2; + + err = trilogy_stmt_prepare(&conn, query2, strlen(query2), &stmt2); + ASSERT_OK(err); + ASSERT(stmt2.connection); + + ASSERT_EQ(0, stmt2.parameter_count); + + trilogy_column_packet_t param2; + err = trilogy_read_full_column(&conn, ¶m2); + ASSERT_OK(err); + + ASSERT_EQ(1, stmt2.column_count); + + ASSERT_EQ(conn.prepared_statements, &stmt2); + + err = trilogy_stmt_close(&conn, &stmt2); + ASSERT_OK(err); + ASSERT_EQ(conn.prepared_statements, &stmt); err = trilogy_stmt_close(&conn, &stmt); ASSERT_OK(err); + ASSERT_EQ(conn.prepared_statements, NULL); + + trilogy_free(&conn); + PASS(); +} + +TEST test_blocking_stmt_conn_close() +{ + trilogy_conn_t conn; + + connect_conn(&conn); + + const char *query = "SELECT ?"; + trilogy_stmt_t stmt; + + int err = trilogy_stmt_prepare(&conn, query, strlen(query), &stmt); + ASSERT_OK(err); + ASSERT(stmt.connection); + + ASSERT_EQ(1, stmt.parameter_count); + + trilogy_column_packet_t param; + err = trilogy_read_full_column(&conn, ¶m); + ASSERT_OK(err); + + ASSERT_EQ(1, stmt.column_count); + + trilogy_column_packet_t column_def; + err = trilogy_read_full_column(&conn, &column_def); + ASSERT_OK(err); + + const char *query2 = "SELECT YEAR('2022-01-31')"; + trilogy_stmt_t stmt2; + + err = trilogy_stmt_prepare(&conn, query2, strlen(query2), &stmt2); + ASSERT_OK(err); + ASSERT(stmt2.connection); + + ASSERT_EQ(0, stmt2.parameter_count); + + trilogy_column_packet_t param2; + err = trilogy_read_full_column(&conn, ¶m2); + ASSERT_OK(err); + + ASSERT_EQ(1, stmt2.column_count); trilogy_free(&conn); + + ASSERT(stmt.connection == NULL); + ASSERT(stmt2.connection == NULL); + PASS(); } @@ -881,6 +965,7 @@ int blocking_test() RUN_TEST(test_blocking_stmt_execute_year); RUN_TEST(test_blocking_stmt_reset); RUN_TEST(test_blocking_stmt_close); + RUN_TEST(test_blocking_stmt_conn_close); return 0; }