Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions app/ldap_protocol/filter_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,26 @@
)
from ldap_protocol.utils.helpers import ft_to_dt
from ldap_protocol.utils.queries import get_path_filter, get_search_path
from repo.pg.tables import groups_table, queryable_attr as qa, users_table
from repo.pg.tables import (
directory_table,
groups_table,
queryable_attr as qa,
users_table,
)

from .asn1parser import ASN1Row, TagNumbers
from .objects import LDAPMatchingRule
from .utils.cte import find_members_recursive_cte, get_filter_from_path
from .utils.cte import (
find_members_recursive_cte,
find_root_group_recursive_cte,
get_filter_from_path,
)

_MEMBERS_ATTRS = {
"member",
"memberof",
f"memberof:{LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL}:",
f"member:{LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL}:",
}

_RULE_POS = 0
Expand Down Expand Up @@ -289,6 +299,8 @@ def _get_member_filter_function(
return self._recursive_filter_memberof
return self._filter_memberof
elif attribute == "member":
if oid == LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL:
return self._recursive_filter_member
return self._filter_member
else:
raise ValueError("Incorrect attribute specified")
Expand Down Expand Up @@ -317,6 +329,24 @@ def _filter_memberof(self, dn: str) -> UnaryExpression:
),
) # type: ignore

def _recursive_filter_member(self, dn: str) -> UnaryExpression:
"""Retrieve query conditions with the member attribute (recursive)."""
cte = find_root_group_recursive_cte([dn])

source_directory_id = (
select(directory_table.c.id)
.where(get_filter_from_path(dn))
.scalar_subquery()
)

return qa(Directory.id).in_(
select(cte.c.directory_id)
.where(
cte.c.directory_id != source_directory_id,
)
.distinct(),
) # type: ignore

def _filter_member(self, dn: str) -> UnaryExpression:
"""Retrieve query conditions with the member attribute."""
user_id_subquery = (
Expand Down
110 changes: 110 additions & 0 deletions tests/test_api/test_main/test_router/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from enums import EntityTypeNames
from ldap_protocol.ldap_codes import LDAPCodes
from ldap_protocol.ldap_requests.modify import Operation
from tests.search_request_datasets import (
test_search_by_rule_anr_dataset,
test_search_by_rule_bit_and_dataset,
Expand Down Expand Up @@ -304,6 +305,115 @@ async def test_api_search_recursive_memberof(http_client: AsyncClient) -> None:
assert all(obj["object_name"] in members for obj in data["search_result"])


@pytest.mark.asyncio
@pytest.mark.usefixtures("session")
async def test_search_recursive_member(
http_client: AsyncClient,
) -> None:
"""Test recursive member search for user0."""
user = "cn=user0,cn=users,dc=md,dc=test"
expected_groups = [
"cn=domain admins,cn=Groups,dc=md,dc=test",
]
response = await http_client.post(
"entry/search",
json={
"base_object": "dc=md,dc=test",
"scope": 2,
"deref_aliases": 0,
"size_limit": 1000,
"time_limit": 10,
"types_only": True,
"filter": f"(member:1.2.840.113556.1.4.1941:={user})",
"attributes": [],
"page_number": 1,
},
)
data = response.json()
assert data["resultCode"] == LDAPCodes.SUCCESS
dns = {obj["object_name"] for obj in data["search_result"]}
for group in expected_groups:
assert group in dns, f"Group {group} not found in search results"
assert len(data["search_result"]) >= 1


@pytest.mark.asyncio
@pytest.mark.usefixtures("session")
async def test_search_recursive_member_for_many_roots(
http_client: AsyncClient,
) -> None:
"""Test recursive member search with nested groups chain."""

async def _create_group(dn: str, name: str) -> None:
response = await http_client.post(
"/entry/add",
json={
"entry": dn,
"password": None,
"attributes": [
{"type": "name", "vals": [name]},
{"type": "cn", "vals": [name]},
{
"type": "objectClass",
"vals": ["top", "posixGroup", "group"],
},
],
},
)
assert response.json().get("resultCode") == LDAPCodes.SUCCESS

async def _add_member(dn: str, member: str) -> None:
response = await http_client.patch(
"/entry/update",
json={
"object": dn,
"changes": [
{
"operation": Operation.ADD,
"modification": {"type": "member", "vals": [member]},
},
],
},
)
assert response.json().get("resultCode") == LDAPCodes.SUCCESS

group1_dn = "cn=recursive_test_group1,cn=Groups,dc=md,dc=test"
group2_dn = "cn=recursive_test_group2,cn=Groups,dc=md,dc=test"
group3_dn = "cn=recursive_test_group3,cn=Groups,dc=md,dc=test"
user = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test"

await _create_group(group3_dn, "recursive_test_group3")
await _create_group(group2_dn, "recursive_test_group2")
await _create_group(group1_dn, "recursive_test_group1")

await _add_member(group1_dn, user)
await _add_member(group2_dn, group1_dn)
await _add_member(group3_dn, group2_dn)

response = await http_client.post(
"entry/search",
json={
"base_object": "dc=md,dc=test",
"scope": 2,
"deref_aliases": 0,
"size_limit": 1000,
"time_limit": 10,
"types_only": True,
"filter": f"(member:1.2.840.113556.1.4.1941:={user})",
"attributes": [],
"page_number": 1,
},
)
data = response.json()
assert data["resultCode"] == LDAPCodes.SUCCESS
dns = {obj["object_name"] for obj in data["search_result"]}

expected_groups = [group1_dn, group2_dn, group3_dn]
for group in expected_groups:
assert group in dns
assert "cn=domain admins,cn=Groups,dc=md,dc=test" in dns


@pytest.mark.asyncio
@pytest.mark.usefixtures("session")
@pytest.mark.parametrize("dataset", test_search_by_rule_anr_dataset)
Expand Down