diff --git a/asyncmy/connection.pyx b/asyncmy/connection.pyx index 3d1d7e7..ad9f3d3 100644 --- a/asyncmy/connection.pyx +++ b/asyncmy/connection.pyx @@ -108,6 +108,9 @@ class Connection: :param host: Host where the database server is located. :param user: Username to log in as. :param password: Password to use. + :param password_creator: + Optional callable or coroutine that returns a password string + every time a new connection is established. :param database: Database to use, None to not use a particular one. :param port: MySQL port to use, default is usually OK. (default: 3306) :param unix_socket: Use a unix socket rather than TCP/IP. @@ -152,6 +155,7 @@ class Connection: *, user=None, # The first four arguments is based on DB-API 2.0 recommendation. password="", + password_creator=None, host=None, database=None, unix_socket=None, @@ -240,6 +244,7 @@ class Connection: raise ValueError("port should be of type int") self._user = user or DEFAULT_USER self._password = password or b"" + self._password_creator = password_creator if isinstance(self._password, str): self._password = self._password.encode("latin1") self._db = database @@ -549,6 +554,12 @@ class Connection: return self._reader, self._writer try: + if self._password_creator is not None: + new_pw = self._password_creator() + if asyncio.iscoroutine(new_pw): + new_pw = await new_pw + self._password = new_pw.encode("latin1") + if self._unix_socket: self._reader, self._writer = await asyncio.wait_for(asyncio.open_unix_connection(self._unix_socket), timeout=self._connect_timeout, ) @@ -1281,6 +1292,7 @@ class LoadLocalFile: def connect(user=None, password="", + password_creator=None, host=None, database=None, unix_socket=None, @@ -1310,6 +1322,7 @@ def connect(user=None, coro = _connect( user=user, password=password, + password_creator=password_creator, host=host, database=database, unix_socket=unix_socket, diff --git a/conftest.py b/conftest.py index 00e30e9..6216ad1 100644 --- a/conftest.py +++ b/conftest.py @@ -7,13 +7,27 @@ from asyncmy import connect from asyncmy.cursors import DictCursor -connection_kwargs = dict( - host="127.0.0.1", - port=3306, - user="root", - password=os.getenv("MYSQL_PASS") or "123456", - echo=True, -) +def mysql_password_creator(): + """Return the MySQL password dynamically""" + return os.getenv("MYSQL_PASS") or "123456" + + +@pytest_asyncio.fixture(params=["static", "creator"], scope="session") +def connection_kwargs(request): + """Provide connection args for both static and dynamic password modes.""" + base = dict( + host="127.0.0.1", + port=3306, + user="root", + echo=True, + ) + + if request.param == "static": + base["password"] = os.getenv("MYSQL_PASS") or "123456" + else: + base["password_creator"] = mysql_password_creator + + return base @pytest_asyncio.fixture(scope="session") @@ -30,7 +44,7 @@ def event_loop(): @pytest_asyncio.fixture(scope="session") -async def connection(): +async def connection(connection_kwargs): conn = await connect(**connection_kwargs) yield conn await conn.ensure_closed() @@ -63,8 +77,8 @@ async def truncate_table(connection): @pytest_asyncio.fixture(scope="session") -async def pool(): +async def pool(connection_kwargs): pool = await asyncmy.create_pool(**connection_kwargs) yield pool pool.close() - await pool.wait_closed() + await pool.wait_closed() \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index c677951..628bcf5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,11 +4,10 @@ from asyncmy.connection import Connection from asyncmy.errors import OperationalError -from conftest import connection_kwargs @pytest.mark.asyncio -async def test_connect(): +async def test_connect(connection_kwargs): connection = Connection(**connection_kwargs) await connection.connect() assert connection._connected @@ -22,7 +21,7 @@ async def test_connect(): @pytest.mark.asyncio -async def test_read_timeout(): +async def test_read_timeout(connection_kwargs): with pytest.raises(OperationalError): connection = Connection(read_timeout=1, **connection_kwargs) await connection.connect()