From c70a8a23eced302943e223002af9d813c23e33c7 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Fri, 16 Jan 2026 12:01:30 +0300 Subject: [PATCH 1/9] added cache func and cached get_base_directories --- app/ldap_protocol/utils/helpers.py | 76 +++++++++++++++++++++++++++++- app/ldap_protocol/utils/queries.py | 2 + 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 296ff5978..73b061050 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -130,6 +130,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +import asyncio import functools import hashlib import random @@ -138,19 +139,23 @@ import time from calendar import timegm from datetime import datetime +from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Callable +from typing import Any, Callable, Iterable from zoneinfo import ZoneInfo from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm.attributes import instance_state from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable from entities import Directory +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + def validate_entry(entry: str) -> bool: """Validate entry str. @@ -402,3 +407,72 @@ async def explain_query( for row in await session.execute(explain(query, analyze=True)) ), ) + + +def has_expired_sqla_objs(obj: Any, max_depth: int = 3) -> bool: + def _check(value: Any) -> bool: + try: + state = instance_state(value) + return bool(state.expired_attributes) + except AttributeError: + return False + + def _walk(value: Any, depth: int = 0) -> bool: + if depth > max_depth: + return False + + if _check(value): + return True + + if isinstance(value, str | bytes | bytearray): + return False + + if isinstance(value, dict): + return any(_walk(v, depth + 1) for v in value.values()) + + if isinstance(value, Iterable): + return any(_walk(v, depth + 1) for v in value) + + return False + + return _walk(obj) + + +def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: + cache: dict = {} + locks: dict = {} + + def _is_value_expired( + value: Any, + now: float, + expires_at: float | None, + ) -> bool: + return bool( + expires_at and expires_at < now or has_expired_sqla_objs(value), + ) + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> Any: + key = (args, tuple(sorted(kwargs.items()))) + now = time.monotonic() + if key not in locks: + locks[key] = asyncio.Lock() + + async with locks[key]: + if key in cache: + value, expires_at = cache[key] + if not _is_value_expired(value, now, expires_at): + return value + else: + del cache[key] + + result = await func(*args, **kwargs) + expires_at = now + ttl if ttl else None + cache[key] = (result, expires_at) + del locks[key] + return result + + return wrapper + + return decorator diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index a1f9243de..4e40514aa 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -27,6 +27,7 @@ from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( + async_lru_cache, create_integer_hash, create_object_sid, dn_is_base_directory, @@ -35,6 +36,7 @@ ) +@async_lru_cache() async def get_base_directories(session: AsyncSession) -> list[Directory]: """Get base domain directories.""" result = await session.execute( From 6e414332b29c9d98acc3c50a368f6db9524071ab Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 12:22:19 +0300 Subject: [PATCH 2/9] refactor: change get_base_directories output to dto --- .../versions/16a9fa2c1f1e_rename_readonly_group.py | 7 ++++--- .../71e642808369_add_directory_is_system.py | 9 +++++++-- app/entities.py | 4 ++-- app/ldap_protocol/auth/setup_gateway.py | 10 ++++++---- app/ldap_protocol/ldap_requests/add.py | 2 +- app/ldap_protocol/ldap_requests/modify_dn.py | 2 +- app/ldap_protocol/ldap_requests/search.py | 3 ++- app/ldap_protocol/roles/role_use_case.py | 3 ++- app/ldap_protocol/utils/helpers.py | 13 ++++++++++--- app/ldap_protocol/utils/queries.py | 13 ++++++++----- pyproject.toml | 1 + tests/test_api/test_auth/test_router.py | 2 +- 12 files changed, 45 insertions(+), 24 deletions(-) diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index b331dddd5..cf6ac280e 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -43,8 +43,8 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = READ_ONLY_GROUP_NAME - - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) + path = ro_dir.parent.path if ro_dir.parent else [] + ro_dir.create_path(path, ro_dir.get_dn_prefix()) session.execute( update(Attribute) @@ -92,7 +92,8 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 ro_dir.name = "readonly domain controllers" - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) + path = ro_dir.parent.path if ro_dir.parent else [] + ro_dir.create_path(path, ro_dir.get_dn_prefix()) session.execute( update(Attribute) diff --git a/app/alembic/versions/71e642808369_add_directory_is_system.py b/app/alembic/versions/71e642808369_add_directory_is_system.py index 2526190e4..48ece1bc4 100644 --- a/app/alembic/versions/71e642808369_add_directory_is_system.py +++ b/app/alembic/versions/71e642808369_add_directory_is_system.py @@ -56,8 +56,13 @@ async def _indicate_system_directories( if not base_dn_list: return - for base_dn in base_dn_list: - base_dn.is_system = True + await session.execute( + update(Directory) + .where( + qa(Directory.parent_id).is_(None), + ) + .values(is_system=True), + ) await session.flush() diff --git a/app/entities.py b/app/entities.py index 53f5c95e9..acc32675a 100644 --- a/app/entities.py +++ b/app/entities.py @@ -270,10 +270,10 @@ def path_dn(self) -> str: def create_path( self, - parent: Directory | None = None, + parent_path: list | None = None, dn: str = "cn", ) -> None: - pre = parent.path if parent else [] + pre = parent_path or [] self.path = pre + [self.get_dn(dn)] self.depth = len(self.path) self.rdname = dn diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index b5bfe580a..a9aefd094 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -11,6 +11,7 @@ from sqlalchemy import exists, select from sqlalchemy.ext.asyncio import AsyncSession +from dtos import DirectoryDTO from entities import Attribute, Directory, Group, NetworkPolicy, User from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, @@ -124,21 +125,22 @@ async def create_dir( self, data: dict, is_system: bool, - domain: Directory, - parent: Directory | None = None, + domain: Directory | DirectoryDTO, + parent: Directory | DirectoryDTO | None = None, ) -> None: """Create data recursively.""" dir_ = Directory( is_system=is_system, object_class=data["object_class"], name=data["name"], - parent=parent, ) dir_.groups = [] - dir_.create_path(parent, dir_.get_dn_prefix()) + path = parent.path if parent else [] + dir_.create_path(path, dir_.get_dn_prefix()) self._session.add(dir_) await self._session.flush() + dir_.parent_id = parent.id if parent else None await self._session.refresh(dir_, ["id"]) self._session.add( diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 6f29fe9af..b000b3798 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -211,7 +211,7 @@ async def handle( # noqa: C901 parent=parent, ) - new_dir.create_path(parent, new_dn) + new_dir.create_path(parent.path, new_dn) ctx.session.add(new_dir) await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 7c315eadd..0ac906ce8 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -199,7 +199,7 @@ async def handle( return directory.parent = parent_dir - directory.create_path(directory.parent, dn=new_dn) + directory.create_path(parent_dir.path, dn=new_dn) try: await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 1f9579dc2..ebd5730ac 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -23,6 +23,7 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from sqlalchemy.sql.expression import Select +from dtos import DirectoryDTO from entities import ( Attribute, AttributeType, @@ -367,7 +368,7 @@ def _mutate_query_with_attributes_to_load( def _build_query( self, - base_directories: list[Directory], + base_directories: list[DirectoryDTO], user: UserSchema, access_manager: AccessManager, ) -> Select[tuple[Directory]]: diff --git a/app/ldap_protocol/roles/role_use_case.py b/app/ldap_protocol/roles/role_use_case.py index 1e978a3f1..08951cca7 100644 --- a/app/ldap_protocol/roles/role_use_case.py +++ b/app/ldap_protocol/roles/role_use_case.py @@ -6,6 +6,7 @@ from sqlalchemy import and_, insert, literal, or_, select +from dtos import DirectoryDTO from entities import AccessControlEntry, AceType, Directory, Role from enums import AuthorizationRules, RoleConstants, RoleScope from ldap_protocol.utils.queries import get_base_directories @@ -40,7 +41,7 @@ def __init__( async def inherit_parent_aces( self, - parent_directory: Directory, + parent_directory: Directory | DirectoryDTO, directory: Directory, ) -> None: """Inherit access control entries from the parent directory. diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 73b061050..4fb72bf80 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -152,6 +152,7 @@ from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable +from dtos import DirectoryDTO from entities import Directory DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes @@ -197,12 +198,18 @@ def validate_attribute(attribute: str) -> bool: ) -def is_dn_in_base_directory(base_directory: Directory, entry: str) -> bool: +def is_dn_in_base_directory( + base_directory: Directory | DirectoryDTO, + entry: str, +) -> bool: """Check if an entry in a base dn.""" return entry.lower().endswith(base_directory.path_dn.lower()) -def dn_is_base_directory(base_directory: Directory, entry: str) -> bool: +def dn_is_base_directory( + base_directory: Directory | DirectoryDTO, + entry: str, +) -> bool: """Check if an entry is a base dn.""" return base_directory.path_dn.lower() == entry.lower() @@ -307,7 +314,7 @@ def string_to_sid(sid_string: str) -> bytes: def create_object_sid( - domain: Directory, + domain: Directory | DirectoryDTO, rid: int, reserved: bool = False, ) -> str: diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 4e40514aa..694b52b69 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import InstrumentedAttribute, joinedload, selectinload from sqlalchemy.sql.expression import ColumnElement +from dtos import DirectoryDTO, _directory_sqla_obj_to_dto from entities import Attribute, Directory, Group, User from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, @@ -37,13 +38,15 @@ @async_lru_cache() -async def get_base_directories(session: AsyncSession) -> list[Directory]: +async def get_base_directories(session: AsyncSession) -> list[DirectoryDTO]: """Get base domain directories.""" result = await session.execute( select(Directory) .filter(qa(Directory.parent_id).is_(None)), ) # fmt: skip - return list(result.scalars().all()) + return [ + _directory_sqla_obj_to_dto(dir_) for dir_ in result.scalars().all() + ] async def get_user(session: AsyncSession, name: str) -> User | None: @@ -364,14 +367,14 @@ async def create_group( dir_ = Directory( object_class="", name=name, - parent=parent, + parent_id=parent.id, ) session.add(dir_) await session.flush() - await session.refresh(dir_, ["id"]) + await session.refresh(dir_, ["id", "parent_id", "parent"]) group = Group(directory_id=dir_.id) - dir_.create_path(parent) + dir_.create_path(parent.path) session.add(group) dir_.object_sid = create_object_sid( diff --git a/pyproject.toml b/pyproject.toml index f7adf0e26..c25e9d974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -222,6 +222,7 @@ known-first-party = [ "extra", "enums", "errors", + "dtos", ] known-third-party = [ "alembic", # https://github.com/astral-sh/ruff/issues/10519 diff --git a/tests/test_api/test_auth/test_router.py b/tests/test_api/test_auth/test_router.py index 26c0e4523..c13c0a5a6 100644 --- a/tests/test_api/test_auth/test_router.py +++ b/tests/test_api/test_auth/test_router.py @@ -13,7 +13,6 @@ from fastapi import status from httpx import AsyncClient from jose import jwt -from password_utils import PasswordUtils from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -26,6 +25,7 @@ from ldap_protocol.ldap_requests.modify import Operation from ldap_protocol.session_storage import SessionStorage from ldap_protocol.utils.queries import get_search_path +from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds From 97a98245bff312cfdd99cf8c7c115f1f4ffa20fa Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 12:28:44 +0300 Subject: [PATCH 3/9] refactor: deleted has_expired_sqla_objs --- app/ldap_protocol/utils/helpers.py | 43 ++---------------------------- app/ldap_protocol/utils/queries.py | 2 +- 2 files changed, 3 insertions(+), 42 deletions(-) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 4fb72bf80..bcc9845a9 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -142,13 +142,12 @@ from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Any, Callable, Iterable +from typing import Any, Callable from zoneinfo import ZoneInfo from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm.attributes import instance_state from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable @@ -416,48 +415,10 @@ async def explain_query( ) -def has_expired_sqla_objs(obj: Any, max_depth: int = 3) -> bool: - def _check(value: Any) -> bool: - try: - state = instance_state(value) - return bool(state.expired_attributes) - except AttributeError: - return False - - def _walk(value: Any, depth: int = 0) -> bool: - if depth > max_depth: - return False - - if _check(value): - return True - - if isinstance(value, str | bytes | bytearray): - return False - - if isinstance(value, dict): - return any(_walk(v, depth + 1) for v in value.values()) - - if isinstance(value, Iterable): - return any(_walk(v, depth + 1) for v in value) - - return False - - return _walk(obj) - - def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: cache: dict = {} locks: dict = {} - def _is_value_expired( - value: Any, - now: float, - expires_at: float | None, - ) -> bool: - return bool( - expires_at and expires_at < now or has_expired_sqla_objs(value), - ) - def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args: tuple, **kwargs: dict) -> Any: @@ -469,7 +430,7 @@ async def wrapper(*args: tuple, **kwargs: dict) -> Any: async with locks[key]: if key in cache: value, expires_at = cache[key] - if not _is_value_expired(value, now, expires_at): + if not expires_at or expires_at > now: return value else: del cache[key] diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 694b52b69..d2a8e7420 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -371,7 +371,7 @@ async def create_group( ) session.add(dir_) await session.flush() - await session.refresh(dir_, ["id", "parent_id", "parent"]) + await session.refresh(dir_, ["id"]) group = Group(directory_id=dir_.id) dir_.create_path(parent.path) From ded6d9d065f50e7417ee07e17100af4ccb4ea257 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 12:31:09 +0300 Subject: [PATCH 4/9] add: dtos file --- app/dtos.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 app/dtos.py diff --git a/app/dtos.py b/app/dtos.py new file mode 100644 index 000000000..7efee30ca --- /dev/null +++ b/app/dtos.py @@ -0,0 +1,91 @@ +"""Module for dtos.""" + +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import ClassVar + +from adaptix.conversion import get_converter + +from entities import Directory, DistinguishedNamePrefix + + +@dataclass +class DirectoryDTO: + id: int + name: str + is_system: bool + object_sid: str + object_guid: uuid.UUID + parent_id: int | None + entity_type_id: int | None + object_class: str + rdname: str + created_at: datetime | None + updated_at: datetime | None + depth: int + password_policy_id: int | None + path: list[str] + + search_fields: ClassVar[dict[str, str]] = { + "name": "name", + "objectguid": "objectGUID", + "objectsid": "objectSid", + } + ro_fields: ClassVar[set[str]] = { + "uid", + "whencreated", + "lastlogon", + "authtimestamp", + "objectguid", + "objectsid", + "entitytypename", + } + + def get_dn_prefix(self) -> DistinguishedNamePrefix: + return { + "organizationalUnit": "ou", + "domain": "dc", + "container": "cn", + }.get( + self.object_class, + "cn", + ) # type: ignore + + def get_dn(self, dn: str = "cn") -> str: + return f"{dn}={self.name}" + + @property + def is_domain(self) -> bool: + return not self.parent_id and self.object_class == "domain" + + @property + def host_principal(self) -> str: + return f"host/{self.name}" + + @property + def path_dn(self) -> str: + return ",".join(reversed(self.path)) + + def create_path( + self, + parent: Directory | None = None, + dn: str = "cn", + ) -> None: + pre = parent.path if parent else [] + self.path = pre + [self.get_dn(dn)] + self.depth = len(self.path) + self.rdname = dn + + @property + def relative_id(self) -> str: + """Get RID from objectSid. + + Relative Identifier (RID) is the last sub-authority value of a SID. + """ + if "-" in self.object_sid: + return self.object_sid.split("-")[-1] + return "" + + +_directory_sqla_obj_to_dto = get_converter(Directory, DirectoryDTO) From 097e5f92723e4451021a80c883482d3365355cd5 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 18:21:54 +0300 Subject: [PATCH 5/9] refactor: add invalidation --- app/dtos.py | 10 ----- app/ldap_protocol/auth/setup_gateway.py | 2 + app/ldap_protocol/utils/helpers.py | 50 +++++++++++-------------- app/ldap_protocol/utils/queries.py | 4 +- 4 files changed, 25 insertions(+), 41 deletions(-) diff --git a/app/dtos.py b/app/dtos.py index 7efee30ca..41d090c9e 100644 --- a/app/dtos.py +++ b/app/dtos.py @@ -67,16 +67,6 @@ def host_principal(self) -> str: def path_dn(self) -> str: return ",".join(reversed(self.path)) - def create_path( - self, - parent: Directory | None = None, - dn: str = "cn", - ) -> None: - pre = parent.path if parent else [] - self.path = pre + [self.get_dn(dn)] - self.depth = len(self.path) - self.rdname = dn - @property def relative_id(self) -> str: """Get RID from objectSid. diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index a9aefd094..73be8a2e0 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -17,6 +17,7 @@ AttributeValueValidator, ) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils @@ -114,6 +115,7 @@ async def setup_enviroment( domain=domain, parent=domain, ) + base_directories_cache.clear() except Exception: import traceback diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index bcc9845a9..d02b38e49 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -142,7 +142,7 @@ from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Any, Callable +from typing import Any, Callable, Generic, TypeVar from zoneinfo import ZoneInfo from loguru import logger @@ -154,8 +154,6 @@ from dtos import DirectoryDTO from entities import Directory -DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes - def validate_entry(entry: str) -> bool: """Validate entry str. @@ -198,7 +196,7 @@ def validate_attribute(attribute: str) -> bool: def is_dn_in_base_directory( - base_directory: Directory | DirectoryDTO, + base_directory: DirectoryDTO, entry: str, ) -> bool: """Check if an entry in a base dn.""" @@ -206,7 +204,7 @@ def is_dn_in_base_directory( def dn_is_base_directory( - base_directory: Directory | DirectoryDTO, + base_directory: DirectoryDTO, entry: str, ) -> bool: """Check if an entry is a base dn.""" @@ -415,32 +413,26 @@ async def explain_query( ) -def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: - cache: dict = {} - locks: dict = {} +# def async_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: +# """Cache for get_base_directories""" +# cache: list[tuple[list[DirectoryDTO], float | None]] = [] - def decorator(func: Callable) -> Callable: - @wraps(func) - async def wrapper(*args: tuple, **kwargs: dict) -> Any: - key = (args, tuple(sorted(kwargs.items()))) - now = time.monotonic() - if key not in locks: - locks[key] = asyncio.Lock() +# def decorator(func: Callable) -> Callable: +# @wraps(func) +# async def wrapper(*args: tuple, **kwargs: dict) -> list[DirectoryDTO]: +# if cache: +# value, expires_at = cache[0] +# if not expires_at or expires_at > time.monotonic(): +# return value +# else: +# cache.clear() - async with locks[key]: - if key in cache: - value, expires_at = cache[key] - if not expires_at or expires_at > now: - return value - else: - del cache[key] +# result = await func(*args, **kwargs) +# expires_at = time.monotonic() + ttl if ttl else None +# cache.append((result, expires_at)) - result = await func(*args, **kwargs) - expires_at = now + ttl if ttl else None - cache[key] = (result, expires_at) - del locks[key] - return result +# return result - return wrapper +# return wrapper - return decorator +# return decorator diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index d2a8e7420..64853a708 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -26,9 +26,9 @@ queryable_attr as qa, ) +from .async_cache import base_directories_cache from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( - async_lru_cache, create_integer_hash, create_object_sid, dn_is_base_directory, @@ -37,7 +37,7 @@ ) -@async_lru_cache() +@base_directories_cache async def get_base_directories(session: AsyncSession) -> list[DirectoryDTO]: """Get base domain directories.""" result = await session.execute( From 1ded88e4aff24e8da3707e8c589b632bd8a0e5d3 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Tue, 20 Jan 2026 18:29:19 +0300 Subject: [PATCH 6/9] refactor: add separate cache file --- app/ldap_protocol/utils/async_cache.py | 42 ++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 app/ldap_protocol/utils/async_cache.py diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py new file mode 100644 index 000000000..446998bbc --- /dev/null +++ b/app/ldap_protocol/utils/async_cache.py @@ -0,0 +1,42 @@ +"""Async cache implementation.""" +import time +from functools import wraps +from typing import Callable, Generic, TypeVar + +from dtos import DirectoryDTO + +T = TypeVar("T") +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + + +class AsyncTTLCache(Generic[T]): + def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None: + self._ttl = ttl + self._value: T | None = None + self._expires_at: float | None = None + + def clear(self) -> None: + self._value = None + self._expires_at = None + + def __call__(self, func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> T: + if self._value is not None: + if not self._expires_at or self._expires_at > time.monotonic(): + return self._value + self.clear() + + result = await func(*args, **kwargs) + + self._value = result + self._expires_at = ( + time.monotonic() + self._ttl if self._ttl else None + ) + + return result + + return wrapper + + +base_directories_cache = AsyncTTLCache[list[DirectoryDTO]]() From 1f1e034a0b867648f97ff7ef91e700dbe3442969 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Wed, 21 Jan 2026 12:57:27 +0300 Subject: [PATCH 7/9] add: add cache clear to modify_dn --- app/ldap_protocol/ldap_requests/modify_dn.py | 2 ++ app/ldap_protocol/utils/async_cache.py | 1 + app/ldap_protocol/utils/helpers.py | 29 +------------------- interface | 2 +- 4 files changed, 5 insertions(+), 29 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 0ac906ce8..13e9b19e8 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -19,6 +19,7 @@ ModifyDNResponse, ) from ldap_protocol.objects import ProtocolRequests +from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.queries import get_filter_from_path, validate_entry from repo.pg.tables import ( ace_directory_memberships_table, @@ -200,6 +201,7 @@ async def handle( directory.parent = parent_dir directory.create_path(parent_dir.path, dn=new_dn) + base_directories_cache.clear() try: await ctx.session.flush() diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py index 446998bbc..ad807060a 100644 --- a/app/ldap_protocol/utils/async_cache.py +++ b/app/ldap_protocol/utils/async_cache.py @@ -1,4 +1,5 @@ """Async cache implementation.""" + import time from functools import wraps from typing import Callable, Generic, TypeVar diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index d02b38e49..ef9905945 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -130,7 +130,6 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -import asyncio import functools import hashlib import random @@ -139,10 +138,9 @@ import time from calendar import timegm from datetime import datetime -from functools import wraps from hashlib import blake2b from operator import attrgetter -from typing import Any, Callable, Generic, TypeVar +from typing import Callable from zoneinfo import ZoneInfo from loguru import logger @@ -411,28 +409,3 @@ async def explain_query( for row in await session.execute(explain(query, analyze=True)) ), ) - - -# def async_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable: -# """Cache for get_base_directories""" -# cache: list[tuple[list[DirectoryDTO], float | None]] = [] - -# def decorator(func: Callable) -> Callable: -# @wraps(func) -# async def wrapper(*args: tuple, **kwargs: dict) -> list[DirectoryDTO]: -# if cache: -# value, expires_at = cache[0] -# if not expires_at or expires_at > time.monotonic(): -# return value -# else: -# cache.clear() - -# result = await func(*args, **kwargs) -# expires_at = time.monotonic() + ttl if ttl else None -# cache.append((result, expires_at)) - -# return result - -# return wrapper - -# return decorator diff --git a/interface b/interface index 97bbc08dd..95ed5e191 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 97bbc08dda7584f579f756d8b09abe60db67b47b +Subproject commit 95ed5e191cdafa07b1dfac96a1659926679ead97 From 36caf00ba125ebde627fa6bbdf19813868592901 Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Wed, 21 Jan 2026 14:46:13 +0300 Subject: [PATCH 8/9] refactor: delete cache clear from modify_dn --- app/ldap_protocol/ldap_requests/modify_dn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 13e9b19e8..0ac906ce8 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -19,7 +19,6 @@ ModifyDNResponse, ) from ldap_protocol.objects import ProtocolRequests -from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.queries import get_filter_from_path, validate_entry from repo.pg.tables import ( ace_directory_memberships_table, @@ -201,7 +200,6 @@ async def handle( directory.parent = parent_dir directory.create_path(parent_dir.path, dn=new_dn) - base_directories_cache.clear() try: await ctx.session.flush() From 86b277820430d89764ea147b5e637242af07bffb Mon Sep 17 00:00:00 2001 From: TheMihMih Date: Wed, 21 Jan 2026 16:27:31 +0300 Subject: [PATCH 9/9] fix: delete field --- app/dtos.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/dtos.py b/app/dtos.py index 41d090c9e..0761b93fd 100644 --- a/app/dtos.py +++ b/app/dtos.py @@ -24,7 +24,6 @@ class DirectoryDTO: created_at: datetime | None updated_at: datetime | None depth: int - password_policy_id: int | None path: list[str] search_fields: ClassVar[dict[str, str]] = {