diff --git a/app/ldap_protocol/filter_interpreter.py b/app/ldap_protocol/filter_interpreter.py index c456fac00..ce0b301c5 100644 --- a/app/ldap_protocol/filter_interpreter.py +++ b/app/ldap_protocol/filter_interpreter.py @@ -32,16 +32,26 @@ ) from ldap_protocol.utils.helpers import ft_to_dt from ldap_protocol.utils.queries import get_path_filter, get_search_path -from repo.pg.tables import groups_table, queryable_attr as qa, users_table +from repo.pg.tables import ( + directory_table, + groups_table, + queryable_attr as qa, + users_table, +) from .asn1parser import ASN1Row, TagNumbers from .objects import LDAPMatchingRule -from .utils.cte import find_members_recursive_cte, get_filter_from_path +from .utils.cte import ( + find_members_recursive_cte, + find_root_group_recursive_cte, + get_filter_from_path, +) _MEMBERS_ATTRS = { "member", "memberof", f"memberof:{LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL}:", + f"member:{LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL}:", } _RULE_POS = 0 @@ -289,6 +299,8 @@ def _get_member_filter_function( return self._recursive_filter_memberof return self._filter_memberof elif attribute == "member": + if oid == LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL: + return self._recursive_filter_member return self._filter_member else: raise ValueError("Incorrect attribute specified") @@ -317,6 +329,24 @@ def _filter_memberof(self, dn: str) -> UnaryExpression: ), ) # type: ignore + def _recursive_filter_member(self, dn: str) -> UnaryExpression: + """Retrieve query conditions with the member attribute (recursive).""" + cte = find_root_group_recursive_cte([dn]) + + source_directory_id = ( + select(directory_table.c.id) + .where(get_filter_from_path(dn)) + .scalar_subquery() + ) + + return qa(Directory.id).in_( + select(cte.c.directory_id) + .where( + cte.c.directory_id != source_directory_id, + ) + .distinct(), + ) # type: ignore + def _filter_member(self, dn: str) -> UnaryExpression: """Retrieve query conditions with the member attribute.""" user_id_subquery = ( diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index 01fbb59c2..77baea3f1 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -9,6 +9,7 @@ from enums import EntityTypeNames from ldap_protocol.ldap_codes import LDAPCodes +from ldap_protocol.ldap_requests.modify import Operation from tests.search_request_datasets import ( test_search_by_rule_anr_dataset, test_search_by_rule_bit_and_dataset, @@ -304,6 +305,115 @@ async def test_api_search_recursive_memberof(http_client: AsyncClient) -> None: assert all(obj["object_name"] in members for obj in data["search_result"]) +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_search_recursive_member( + http_client: AsyncClient, +) -> None: + """Test recursive member search for user0.""" + user = "cn=user0,cn=users,dc=md,dc=test" + expected_groups = [ + "cn=domain admins,cn=Groups,dc=md,dc=test", + ] + response = await http_client.post( + "entry/search", + json={ + "base_object": "dc=md,dc=test", + "scope": 2, + "deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": f"(member:1.2.840.113556.1.4.1941:={user})", + "attributes": [], + "page_number": 1, + }, + ) + data = response.json() + assert data["resultCode"] == LDAPCodes.SUCCESS + dns = {obj["object_name"] for obj in data["search_result"]} + for group in expected_groups: + assert group in dns, f"Group {group} not found in search results" + assert len(data["search_result"]) >= 1 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_search_recursive_member_for_many_roots( + http_client: AsyncClient, +) -> None: + """Test recursive member search with nested groups chain.""" + + async def _create_group(dn: str, name: str) -> None: + response = await http_client.post( + "/entry/add", + json={ + "entry": dn, + "password": None, + "attributes": [ + {"type": "name", "vals": [name]}, + {"type": "cn", "vals": [name]}, + { + "type": "objectClass", + "vals": ["top", "posixGroup", "group"], + }, + ], + }, + ) + assert response.json().get("resultCode") == LDAPCodes.SUCCESS + + async def _add_member(dn: str, member: str) -> None: + response = await http_client.patch( + "/entry/update", + json={ + "object": dn, + "changes": [ + { + "operation": Operation.ADD, + "modification": {"type": "member", "vals": [member]}, + }, + ], + }, + ) + assert response.json().get("resultCode") == LDAPCodes.SUCCESS + + group1_dn = "cn=recursive_test_group1,cn=Groups,dc=md,dc=test" + group2_dn = "cn=recursive_test_group2,cn=Groups,dc=md,dc=test" + group3_dn = "cn=recursive_test_group3,cn=Groups,dc=md,dc=test" + user = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + + await _create_group(group3_dn, "recursive_test_group3") + await _create_group(group2_dn, "recursive_test_group2") + await _create_group(group1_dn, "recursive_test_group1") + + await _add_member(group1_dn, user) + await _add_member(group2_dn, group1_dn) + await _add_member(group3_dn, group2_dn) + + response = await http_client.post( + "entry/search", + json={ + "base_object": "dc=md,dc=test", + "scope": 2, + "deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": f"(member:1.2.840.113556.1.4.1941:={user})", + "attributes": [], + "page_number": 1, + }, + ) + data = response.json() + assert data["resultCode"] == LDAPCodes.SUCCESS + dns = {obj["object_name"] for obj in data["search_result"]} + + expected_groups = [group1_dn, group2_dn, group3_dn] + for group in expected_groups: + assert group in dns + assert "cn=domain admins,cn=Groups,dc=md,dc=test" in dns + + @pytest.mark.asyncio @pytest.mark.usefixtures("session") @pytest.mark.parametrize("dataset", test_search_by_rule_anr_dataset)