diff --git a/CHANGELOG.md b/CHANGELOG.md index a55272b..1b24af9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [unreleased] + +### Fixed +- auth: Escape user-provided values in LDAP search filters to prevent LDAP + injection attacks and support values with parenthesis (#63). + ## [1.6.0] - 2025-10-10 ### Added diff --git a/src/authentication/rfl/authentication/ldap.py b/src/authentication/rfl/authentication/ldap.py index 5f980a1..04bd7de 100644 --- a/src/authentication/rfl/authentication/ldap.py +++ b/src/authentication/rfl/authentication/ldap.py @@ -9,6 +9,7 @@ import logging import ldap +import ldap.filter from .user import AuthenticatedUser from .errors import LDAPAuthenticationError @@ -196,7 +197,8 @@ def _get_groups( search_filter = ( "(&" f"(|{object_class_filter})" - f"(|(memberUid={user_name})(member={user_dn}){gid_filter}))" + f"(|(memberUid={ldap.filter.escape_filter_chars(user_name)})" + f"(member={ldap.filter.escape_filter_chars(user_dn)}){gid_filter}))" ) try: results = connection.search_s( @@ -283,7 +285,8 @@ def _lookup_user_dn(self, user): self._bind(connection) search_filter = ( - f"(&(objectClass={self.user_class})({self.user_name_attribute}={user}))" + f"(&(objectClass={self.user_class})" + f"({self.user_name_attribute}={ldap.filter.escape_filter_chars(user)}))" ) try: results = connection.search_s( diff --git a/src/authentication/rfl/tests/test_ldap.py b/src/authentication/rfl/tests/test_ldap.py index 098ecd9..001f31e 100644 --- a/src/authentication/rfl/tests/test_ldap.py +++ b/src/authentication/rfl/tests/test_ldap.py @@ -11,6 +11,7 @@ import ssl import ldap +import ldap.filter from rfl.authentication.ldap import LDAPAuthentifier from rfl.authentication.errors import LDAPAuthenticationError @@ -167,6 +168,39 @@ def test_lookup_user_dn_enabled_bind_dn(self, mock_ldap_initialize): mock_ldap_object.search_s.assert_called_once() mock_ldap_object.unbind_s.assert_called_once() + @patch.object(ldap, "initialize") + def test_lookup_user_dn_escape_special_chars(self, mock_ldap_initialize): + # enable user DN lookup + self.authentifier.lookup_user_dn = True + # setup LDAP mock + mock_ldap_object = mock_ldap_initialize.return_value + user_with_special_chars = "John Doe(user)" + escaped_user = ldap.filter.escape_filter_chars(user_with_special_chars) + mock_ldap_object.search_s.return_value = [ + ( + f"uid={user_with_special_chars},ou=admins,{self.authentifier.user_base}", + {"cn": [b"John Doe(user)"]}, + ) + ] + + self.assertEqual( + self.authentifier._lookup_user_dn(user_with_special_chars), + f"uid={user_with_special_chars},ou=admins,{self.authentifier.user_base}", + ) + # Verify the filter contains escaped user value + call_args = mock_ldap_object.search_s.call_args + search_filter = call_args[0][2] + # Verify the filter contains escaped user value + self.assertIn( + f"({self.authentifier.user_name_attribute}={escaped_user})", search_filter + ) + # The unescaped value should NOT be in the filter + self.assertNotIn( + f"({self.authentifier.user_name_attribute}={user_with_special_chars})", + search_filter, + ) + mock_ldap_object.unbind_s.assert_called_once() + @patch.object(ldap, "initialize") def test_lookup_user_dn_enabled_bind_dn_missing_password( self, mock_ldap_initialize @@ -268,7 +302,7 @@ def test_lookup_user_dn_enabled_not_found(self, mock_ldap_initialize): mock_ldap_object.unbind_s.assert_called_once() @patch.object(ldap, "initialize") - def test_lookup_user_dn_enabled_too_much_results(self, mock_ldap_initialize): + def test_lookup_user_dn_enabled_too_many_results(self, mock_ldap_initialize): # enable user DN lookup self.authentifier.lookup_user_dn = True # setup LDAP mock @@ -550,6 +584,30 @@ def test_get_groups(self): ) self.assertEqual(groups, ["scientists", "biology"]) + def test_get_groups_escape_special_chars(self): + connection = Mock(spec=ldap.ldapobject.LDAPObject) + connection.search_s.return_value = [ + ("cn=scientists,ou=groups,dc=corp,dc=org", {"cn": [b"scientists"]}), + ] + user_name_with_special = "John Doe(user)" + user_dn_with_special = "uid=John Doe(user),ou=people,dc=corp,dc=org" + escaped_user_name = ldap.filter.escape_filter_chars(user_name_with_special) + escaped_user_dn = ldap.filter.escape_filter_chars(user_dn_with_special) + gid = 42 + groups = self.authentifier._get_groups( + connection, user_name_with_special, user_dn_with_special, gid + ) + self.assertEqual(groups, ["scientists"]) + + connection.search_s.assert_called_once_with( + self.authentifier.group_base, + ldap.SCOPE_SUBTREE, + f"(&(|(objectClass=posixGroup)(objectClass=groupOfNames))" + f"(|(memberUid={escaped_user_name})" + f"(member={escaped_user_dn})(gidNumber={gid})))", + [self.authentifier.group_name_attribute], + ) + def test_get_groups_without_gid(self): connection = Mock(spec=ldap.ldapobject.LDAPObject) connection.search_s.return_value = [