diff --git a/contrib/ruby/lib/trilogy.rb b/contrib/ruby/lib/trilogy.rb index bde8a033..e51afd7e 100644 --- a/contrib/ruby/lib/trilogy.rb +++ b/contrib/ruby/lib/trilogy.rb @@ -7,7 +7,27 @@ require "trilogy/encoding" class Trilogy + Synchronization = Module.new + + source = public_instance_methods(false).map do |method| + <<~RUBY + def #{method}(...) + raise SynchronizationError unless @mutex.try_lock + + begin + super + ensure + @mutex.unlock + end + end + RUBY + end + Synchronization.class_eval(source.join(";")) + + prepend(Synchronization) + def initialize(options = {}) + @mutex = Mutex.new options[:port] = options[:port].to_i if options[:port] mysql_encoding = options[:encoding] || "utf8mb4" encoding = Trilogy::Encoding.find(mysql_encoding) diff --git a/contrib/ruby/lib/trilogy/error.rb b/contrib/ruby/lib/trilogy/error.rb index 880c3a97..5885396d 100644 --- a/contrib/ruby/lib/trilogy/error.rb +++ b/contrib/ruby/lib/trilogy/error.rb @@ -54,6 +54,12 @@ class BaseConnectionError < BaseError include ConnectionError end + class SynchronizationError < BaseError + def initialize(message = "This connection is already in use by another thread or fiber") + super + end + end + # Trilogy::ClientError is the base error type for invalid queries or parameters # that shouldn't be retried. class ClientError < BaseError diff --git a/contrib/ruby/test/client_test.rb b/contrib/ruby/test/client_test.rb index 728193ea..67100319 100644 --- a/contrib/ruby/test/client_test.rb +++ b/contrib/ruby/test/client_test.rb @@ -705,6 +705,17 @@ def test_releases_gvl end end + def test_prevent_concurrent_use + client = new_tcp_client + thread = Thread.new { client.query("SELECT SLEEP(1)") } + thread.join(0.2) + assert_raises Trilogy::SynchronizationError do + client.query("SELECT 1") + end + thread.join + client.close + end + USR1 = Class.new(StandardError) def test_interruptible_when_releasing_gvl