diff --git a/.dns/dns_api.py b/.dns/dns_api.py deleted file mode 100644 index f57f23bed..000000000 --- a/.dns/dns_api.py +++ /dev/null @@ -1,1395 +0,0 @@ -"""API for managing Bind9 DNS server. - -Copyright (c) 2025 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -import contextlib -import logging -import os -import re -import subprocess -import tempfile -from collections import defaultdict -from dataclasses import dataclass -from enum import StrEnum -from typing import Annotated, NoReturn - -import dns -import dns.zone -import jinja2 -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status -from pydantic import BaseModel - -logging.basicConfig(level=logging.INFO) - -TEMPLATES: jinja2.Environment = jinja2.Environment( - loader=jinja2.FileSystemLoader("templates/"), - autoescape=True, - keep_trailing_newline=True, -) - -ZONE_FILES_DIR = "/opt" -NAMED_CONF = "/etc/bind/named.conf" -NAMED_LOCAL = "/etc/bind/named.conf.local" -NAMED_OPTIONS = "/etc/bind/named.conf.options" - -FIRST_SETUP_RECORDS = [ - {"name": "_ldap._tcp.", "value": "0 0 389 ", "type": "SRV"}, - {"name": "_ldaps._tcp.", "value": "0 0 636 ", "type": "SRV"}, - {"name": "_kerberos._tcp.", "value": "0 0 88 ", "type": "SRV"}, - {"name": "_kerberos._udp.", "value": "0 0 88 ", "type": "SRV"}, - {"name": "_kdc._tcp.", "value": "0 0 88 ", "type": "SRV"}, - {"name": "_kdc._udp.", "value": "0 0 88 ", "type": "SRV"}, - {"name": "_kpasswd._tcp.", "value": "0 0 464 ", "type": "SRV"}, - {"name": "_kpasswd._udp.", "value": "0 0 464 ", "type": "SRV"}, - # Record for PDC Emulator - { - "name": "_ldap._tcp.pdc._msdcs.", - "value": "0 100 389 ", - "type": "SRV", - }, - # Records for DC Locator (for trusts) - { - "name": "_kerberos._tcp.dc._msdcs.", - "value": "0 100 88 ", - "type": "SRV", - }, - { - "name": "_kerberos._tcp.Default-First-Site-Name._sites.dc._msdcs.", - "value": "0 100 88 ", - "type": "SRV", - }, - { - "name": 
"_ldap._tcp.dc._msdcs.", - "value": "0 100 389 ", - "type": "SRV", - }, - { - "name": "_ldap._tcp.Default-First-Site-Name._sites.dc._msdcs.", - "value": "0 100 389 ", - "type": "SRV", - }, - # Records for Global Catalog - {"name": "_gc._tcp.", "value": "0 100 3268 ", "type": "SRV"}, - { - "name": "_ldap._tcp.Default-First-Site-Name._sites.gc._msdcs.", - "value": "0 100 3268 ", - "type": "SRV", - }, - { - "name": "_ldap._tcp.gc._msdcs.", - "value": "0 100 3268 ", - "type": "SRV", - }, -] - - -class DNSError(Exception): - """Base class for DNS exceptions.""" - - -class DNSZoneCreateError(DNSError): - """DNS zone create error.""" - - -class DNSDomainNotFoundError(DNSError): - """DNS domain not found error.""" - - -class DNSZoneValidationError(DNSError): - """DNS validation error.""" - - -class DNSZoneConfigError(DNSError): - """DNS zone config error.""" - - -class DNSZoneNotFoundError(DNSError): - """DNS zone not found error.""" - - -class DNSZoneType(StrEnum): - """DNS zone types.""" - - MASTER = "master" - FORWARD = "forward" - - -class DNSRecordType(StrEnum): - """DNS record types.""" - - A = "A" - AAAA = "AAAA" - CNAME = "CNAME" - MX = "MX" - NS = "NS" - TXT = "TXT" - SOA = "SOA" - PTR = "PTR" - SRV = "SRV" - - -@dataclass -class DNSRecord: - """Single DNS record.""" - - name: str - value: str - ttl: int - - -@dataclass -class DNSRecords: - """List of DNS records grouped by type.""" - - type: DNSRecordType - records: list[DNSRecord] - - -@dataclass -class DNSZone: - """DNS zone.""" - - name: str - type: DNSZoneType - records: list[DNSRecords] - - -@dataclass -class DNSForwardZone: - """DNS forward zone.""" - - name: str - type: DNSZoneType - forwarders: list[str] - - -class DNSZoneParamName(StrEnum): - """Possible DNS zone option names.""" - - acl = "acl" - forwarders = "forwarders" - ttl = "ttl" - - -class DNSServerParamName(StrEnum): - """Possible DNS server option names.""" - - dnssec = "dnssec-validation" - - -@dataclass -class DNSZoneParam: - """DNS zone 
parameter.""" - - name: DNSZoneParamName - value: str | list[str] | None - - -class DNSZoneCreateRequest(BaseModel): - """DNS zone create request scheme.""" - - zone_name: str - zone_type: DNSZoneType - nameserver: str | None - params: list[DNSZoneParam] - - -class DNSZoneUpdateRequest(BaseModel): - """DNS zone update request scheme.""" - - zone_name: str - params: list[DNSZoneParam] - - -class DNSZoneDeleteRequest(BaseModel): - """DNS zone delete request scheme.""" - - zone_name: str - - -class DNSRecordCreateRequest(BaseModel): - """DNS record create request scheme.""" - - zone_name: str - record_name: str - record_value: str - record_type: str - ttl: int - - -class DNSRecordUpdateRequest(BaseModel): - """DNS record update request scheme.""" - - zone_name: str - record_name: str - record_value: str - record_type: DNSRecordType - ttl: int - - -class DNSRecordDeleteRequest(BaseModel): - """DNS record delete request schem.""" - - zone_name: str - record_name: str - record_value: str - record_type: DNSRecordType - - -class DNSServerSetupRequest(BaseModel): - """DNS server setup request schem.""" - - zone_name: str - - -@dataclass -class DNSServerParam: - """DNS zone parameter.""" - - name: DNSServerParamName - value: str | list[str] - - -class BindDNSServerManager: - """Bind9 DNS server manager.""" - - @staticmethod - def _get_zone_obj_by_zone_name(zone_name) -> dns.zone.Zone: - """Get DNS zone object by zone name. - - Algorithm: - 1. Build the path to the zone file using the zone name. - 2. Load the zone object using dns.zone.from_file. - - Args: - zone_name (str): Name of the DNS zone. - - Returns: - dns.zone.Zone: Zone object. - - """ - zone_file = os.path.join(ZONE_FILES_DIR, f"{zone_name}.zone") - return dns.zone.from_file( - zone_file, - relativize=False, - origin=zone_name, - ) - - def _write_zone_data_to_file( - self, - zone_name: str, - zone: dns.zone.Zone, - ) -> None: - """Write zone data to file and reload the zone. - - Algorithm: - 1. 
Save the zone object to a file. - 2. Call reload to apply changes. - - Args: - zone_name (str): Name of the DNS zone. - zone (dns.zone.Zone): Zone object. - - """ - error = self._check_zone(zone.to_text(), zone_name) - if error: - raise DNSZoneCreateError( - f"Error while writing zone data to file {zone_name}: {error}", - ) - - zone.to_file(os.path.join(ZONE_FILES_DIR, f"{zone_name}.zone")) - self.reload(zone_name) - - def _check_config(self, config: str) -> str | None: - with tempfile.NamedTemporaryFile(mode="w") as tf: - tf.write(config) - tmp_path = tf.name - - result = subprocess.run( # noqa: S603 - ["/usr/bin/named-checkconf", tmp_path], - capture_output=True, - text=True, - ) - - return result.stderr - - def _check_zone(self, zonefile: str, zone_name: str) -> str | None: - with tempfile.NamedTemporaryFile(mode="w") as zf: - zf.write(zonefile) - tmp_path = zf.name - - result = subprocess.run( # noqa: S603 - [ - "/usr/bin/named-checkzone", - "-i", - "none", - zone_name, - tmp_path, - ], - capture_output=True, - text=True, - ) - - return result.stderr - - def _get_base_domain(self) -> str: - """Get base domain. - - Algorithm: - 1. Open named.conf.local. - 2. Get first domain. - - """ - named_local = None - - with open(NAMED_LOCAL) as file: - named_local = file.read() - - pattern = r""" - zone\s+"([^"]+)"\s*{[^}]*? - type\s+master\b[^}]*? - """ - - matches = re.search(pattern, named_local, re.DOTALL | re.VERBOSE) - - if not matches: - raise DNSDomainNotFoundError("Base domain not found") - - return matches.group(1) - - def add_zone( - self, - zone_name: str, - zone_type: str, - nameserver_ip: str | None, - params: list[DNSZoneParam], - ) -> None: - """Add a new DNS zone. - - Algorithm: - 1. Build a dictionary of zone parameters. - 2. Render the zone file and zone options templates. - 3. Process parameters (acl, forwarders, ttl, etc.) and add them - to the zone options. - 4. Write the zone options to named.conf.local. - 5. Restart the server. 
- - Args: - zone_name (str): Name of the DNS zone. - zone_type (str): Type of the DNS zone. - nameserver_ip (str | None): Nameserver IP address. - params (list[DNSZoneParam]): List of zone parameters. - - """ - params_dict = {param.name: param.value for param in params} - - if zone_type != DNSZoneType.FORWARD: - nameserver_ip = ( - nameserver_ip - if nameserver_ip is not None - else os.getenv("DEFAULT_NAMESERVER") - ) - nameserver = ( - self._get_base_domain() - if "in-addr.arpa" in zone_name - else zone_name - ) - - zf_template = TEMPLATES.get_template("zone.template") - zone_file = zf_template.render( - domain=zone_name, - nameserver=nameserver, - ttl=params_dict.get("ttl", 604800), - ) - - zone_error = self._check_zone(zone_file, zone_name) - if zone_error: - raise DNSZoneValidationError( - f"Error in zonefile during adding zone: {zone_error}", - ) - - with open( - os.path.join(ZONE_FILES_DIR, f"{zone_name}.zone"), - "w", - ) as file: - file.write(zone_file) - - if "in-addr.arpa" not in zone_name: - for record in [ - DNSRecord( - name=zone_name, - value=nameserver_ip, - ttl=604800, - ), - DNSRecord( - name=f"ns1.{zone_name}", - value=nameserver_ip, - ttl=604800, - ), - DNSRecord( - name=f"ns2.{zone_name}", - value="127.0.0.1", - ttl=604800, - ), - ]: - self.add_record( - record, - DNSRecordType.A, - zone_name=zone_name, - ) - - zo_template = TEMPLATES.get_template("zone_options.template") - zone_options = zo_template.render( - zone_name=zone_name, - zone_type=zone_type, - forwarders=params_dict.get("forwarders"), - ) - - for param in params: - param_name = param.name if param.name != "acl" else "allow-query" - if ( - param_name == "allow-query" - and zone_type == DNSZoneType.FORWARD - ): - continue - if isinstance(param.value, list): - param_value = "{ " + f"{'; '.join(param.value)};" + " }" - else: - param_value = param.value - - zone_options = self._add_zone_param( - zone_options, - zone_name, - param_name, - param_value, - ) - - config_error = 
self._check_config(zone_options) - if config_error: - raise DNSError( - f"Error with config during adding zone: {config_error}", - ) - - with open(NAMED_LOCAL, "a") as file: - file.write(zone_options) - - self.restart() - - @staticmethod - def _add_zone_param( - named_local: str, - zone_name: str, - param_name: str, - param_value: str, - ) -> str: - """Add a zone parameter to named.conf.local. - - Regex explanation: - - (zone\\s+"{zone_name}"\\s*{{[^}}]*?) - Captures the start of the zone block for the given zone_name, - including all content up to the closing '};'. - - (\\s*}};) - Captures the closing of the zone block - (with optional whitespace). - The regex is used to insert a new parameter - just before the end of the zone block. - - Algorithm: - 1. Use re.sub to add the parameter line inside the zone block. - 2. Return the modified text. - - Args: - named_local (str): Contents of named.conf.local. - zone_name (str): Name of the DNS zone. - param_name (str): Parameter name. - param_value (str): Parameter value. - - Returns: - str: Modified named.conf.local content. - - """ - pattern = rf'(zone\s+"{re.escape(zone_name)}"\s*{{[^}}]*?)(\s*}};)' - replacement = rf"\1\n {param_name} {param_value};\2" - return re.sub(pattern, replacement, named_local, flags=re.DOTALL) - - @staticmethod - def _delete_zone_param( - named_local: str, - zone_name: str, - param_name: str, - ) -> str: - """Delete a zone parameter from named.conf.local. - - Regex explanation: - - (zone\\s+"{zone_name}"\\s*{{) - Captures the start of the zone block for the given zone_name. - - (.*?) - Non-greedy match for any content up to the parameter line. - - (^\\s*{param_name}\\s+(?:[^{{;\\n}}]+|{{[^}}]+}})\\s*;\\s*\\n) - Matches the parameter line (with possible value in braces - or not), including the trailing semicolon and newline. - - (.*?}}) - Matches the rest of the zone block up to the closing brace. - The regex is used to remove the parameter line from the zone block. - - Algorithm: - 1. 
Use re.sub to remove the parameter line from the zone block. - 2. Return the modified text. - - Args: - named_local (str): Contents of named.conf.local. - zone_name (str): Name of the DNS zone. - param_name (str): Parameter name. - - Returns: - str: Modified named.conf.local content. - - """ - pattern = rf""" - (zone\s+"{re.escape(zone_name)}"\s*{{) - (.*?) - ^\s*{re.escape(param_name)}\s+ - (?:[^{{;\n}}]+|{{[^}}]+}}) - \s*;\s*\n - (.*?}}) - """ - - return re.sub( - pattern, - r"\1\2\3", - named_local, - flags=re.DOTALL | re.VERBOSE | re.MULTILINE, - ) - - def _update_zone_param( - self, - named_local: str, - zone_name: str, - param_name: str, - param_value: str, - ) -> str: - """Update a zone parameter in named.conf.local. - - Algorithm: - 1. Remove the old parameter value using _delete_zone_param. - 2. Add the new value using _add_zone_param. - 3. Return the modified text. - - Args: - named_local (str): Contents of named.conf.local. - zone_name (str): Name of the DNS zone. - param_name (str): Parameter name. - param_value (str): Parameter value. - - Returns: - str: Modified named.conf.local content. - - """ - new_named_local = self._delete_zone_param( - named_local, - zone_name, - param_name, - ) - return self._add_zone_param( - new_named_local, - zone_name, - param_name, - param_value, - ) - - def update_zone(self, zone_name: str, params: list[DNSZoneParam]) -> None: - """Update zone parameters. - - Regex explanation: - - ^zone\\s+"{zone_name}"\\s*{{ - Matches the start of the zone block for the given zone_name. - - [^}}]*? - Non-greedy match for any content inside the block up - to the parameter. - - \\s{param_name}\\b - Matches the parameter name as a whole word. - - \\s+(?:[^{{;\\n}}]+|{{[^}}]+}})\\s*; - Matches the parameter value (either a simple value or a block - in braces), followed by a semicolon. - This regex is used to check if the parameter exists in the zone - block. - - Algorithm: - 1. Read named.conf.local content. - 2. 
For each parameter, check if it exists in the zone block - using regex. - 3. If value is None, remove the parameter; otherwise, update or - add it. - 4. Write the modified config back to the file. - - Args: - zone_name (str): Name of the DNS zone. - params (list[DNSZoneParam]): List of zone parameters. - - """ - named_local = None - with open(NAMED_LOCAL) as file: - named_local = file.read() - - for param in params: - param_name = param.name if param.name != "acl" else "allow-query" - pattern = rf""" - ^zone\s+"{re.escape(zone_name)}"\s*{{ - [^}}]*? - \s{re.escape(param_name)}\b - \s+(?:[^{{;\n}}]+|{{[^}}]+}}) - \s*; - """ - has_param = bool( - re.search( - pattern, - named_local, - flags=re.MULTILINE | re.VERBOSE | re.DOTALL, - ), - ) - - if param.value is None: - named_local = self._delete_zone_param( - named_local, - zone_name, - param_name, - ) - continue - - if isinstance(param.value, list): - param_value = "{ " + f"{'; '.join(param.value)};" + " }" - else: - param_value = param.value - - if has_param: - named_local = self._update_zone_param( - named_local, - zone_name, - param_name, - param_value, - ) - else: - named_local = self._add_zone_param( - named_local, - zone_name, - param_name, - param_value, - ) - - error = self._check_config(named_local) - if error: - raise DNSZoneConfigError( - f"Error while updating zone {zone_name}: {error}", - ) - - with open(NAMED_LOCAL, "w") as file: - file.write(named_local) - - self.restart() - - def delete_zone(self, zone_name: str) -> None: - """Delete an existing zone. - - Regex explanation: - - ^\\s*zone\\s+"{zone_name}"\\s*{{ - Matches the start of the zone block for the given zone_name. - - (?:[^{{}}]|{{(?:[^{{}}]|{{[^}}]*}})*}})*? - Non-greedy match for any content inside the block, including - nested braces. - - \\s*}};\\s* - Matches the closing of the zone block (with optional - whitespace). - This regex is used to remove the entire zone block from the config. - - Algorithm: - 1. Read named.conf.local content. 
- 2. Determine the zone type. - 3. Remove the zone block using regex. - 4. If not a forward zone, remove the zone file. - 5. Restart the server. - - Args: - zone_name (str): Name of the DNS zone. - - """ - named_local = None - with open(NAMED_LOCAL) as file: - named_local = file.read() - - zone_type = self.get_zone_type_by_zone_name(zone_name) - - pattern = rf""" - ^\s*zone\s+"{re.escape(zone_name)}"\s*{{ - (?: - [^{{}}] - | - {{(?:[^{{}}]|{{[^}}]*}})*}} - )*? - \s*}};\s* - """ - named_local = re.sub( - pattern, - "", - named_local, - flags=re.MULTILINE | re.VERBOSE | re.DOTALL, - ) - - error = self._check_config(named_local) - if error: - raise DNSZoneConfigError( - f"Error while deleting zone {zone_name}: {error}", - ) - - with open(NAMED_LOCAL, "w") as file: - file.write(named_local) - - if zone_type != DNSZoneType.FORWARD: - with contextlib.suppress(FileNotFoundError): - os.remove(os.path.join(ZONE_FILES_DIR, f"{zone_name}.zone")) - - self.restart() - - def reload(self, zone_name: str | None = None) -> None: - """Reload a zone by name or all zones if no name is provided. - - Algorithm: - 1. Call rndc reload with the zone name or without it. - - Args: - zone_name (str | None): Name of the DNS zone or None. - - """ - subprocess.run( # noqa: S603 - [ - "/usr/sbin/rndc", - "reload", - zone_name if zone_name else "", - ], - ) - - def restart(self) -> None: - """Restart the Bind9 server (reconfig). - - Algorithm: - 1. Call rndc reconfig. - """ - subprocess.run( # noqa: S603 - [ - "/usr/sbin/rndc", - "reconfig", - ], - ) - - def first_setup(self, zone_name: str) -> str: - """Perform initial setup of the Bind9 server. - - Algorithm: - 1. Create a master zone. - 2. Add standard SRV records for services (ldap, kerberos, etc.). - - Args: - zone_name (str): Name of the DNS zone. 
- - """ - self.add_zone( - zone_name, - "master", - None, - params=[], - ) - - self.add_record( - DNSRecord( - name=f"gc._msdcs.{zone_name}", - value=os.getenv("DEFAULT_NAMESERVER"), - ttl=604800, - ), - DNSRecordType.A, - zone_name, - ) - - for record in FIRST_SETUP_RECORDS: - self.add_record( - DNSRecord( - name=f"{record.get('name')}{zone_name}", - value=f"{record.get('value')}{zone_name}.", - ttl=604800, - ), - record.get("type"), - zone_name, - ) - - @staticmethod - def get_zone_type_by_zone_name(zone_name: str) -> DNSZoneType: - """Get the zone type by zone name. - - Regex explanation: - - zone\\s+"{zone_name}"\\s*{{\\s*type\\s*([^;]+); - Matches the zone block for the given zone_name and captures - the type value after 'type'. - The first capturing group contains the zone type - (e.g., master, forward). - - Algorithm: - 1. Read named.conf.local content. - 2. Use regex to find the zone block and extract the type. - - Args: - zone_name (str): Name of the DNS zone. - - Returns: - DNSZoneType: Zone type. - - """ - with open(NAMED_LOCAL) as file: - named_local_settings = file.read() - - pattern = rf'zone\s*"{re.escape(zone_name)}"\s*{{\s*type\s*([^;]+);' - match = re.search(pattern, named_local_settings) - if not match: - raise DNSZoneNotFoundError(f"Zone not found: {zone_name}") - return DNSZoneType(match.group(1).strip()) - - def get_all_records_from_zone( - self, - zone_name: str, - ) -> DNSRecords: - """Get all records from a zone by name. - - Algorithm: - 1. Load the zone object. - 2. Iterate over all rdata and group by type. - 3. Return a list of DNSRecords by type. - - Args: - zone_name (str): Name of the DNS zone. - - Returns: - list[DNSRecords]: List of DNSRecords grouped by type. 
- - """ - result: defaultdict[str, list] = defaultdict(list) - - zone = self._get_zone_obj_by_zone_name(zone_name) - for name, ttl, rdata in zone.iterate_rdatas(): - record_type = rdata.rdtype.name - - result[record_type].append( - DNSRecord( - name=name.to_text(), - value=rdata.to_text(), - ttl=ttl, - ), - ) - - return [ - DNSRecords(type=record_type, records=records) - for record_type, records in result.items() - ] - - def get_all_records(self) -> list[DNSZone]: - """Get all records from all zones. - - Algorithm: - 1. Scan the directory for zone files. - 2. For each file, determine the zone name and type. - 3. Get all records for the zone. - 4. Return a list of DNSZone objects. - - Returns: - list[DNSZone]: List of DNSZone objects. - - """ - zone_files = os.listdir(ZONE_FILES_DIR) - - result: list[DNSZone] = [] - for file in zone_files: - if file.split(".")[-1] != "zone": - continue - zone_name = ".".join(file.split(".")[:-1]) - zone_type = self.get_zone_type_by_zone_name(zone_name) - zone_records = self.get_all_records_from_zone( - zone_name, - ) - result.append( - DNSZone( - name=zone_name, - type=zone_type, - records=zone_records, - ), - ) - - return result - - async def get_forward_zones(self) -> list[DNSForwardZone]: - """Get all forward DNS zones. - - Regex explanation: - - zone\\s+"([^"]+)"\\s*{{ - Captures the zone name. - - [^}}]*?type\\s+forward\\b[^}}]*? - Matches content up to the 'type forward' declaration. - - forwarders\\s*{{([^}}]+)}} - Captures the content inside the forwarders block - (list of forwarder IPs). - The first group is the zone name, - the second group is the forwarders list. - - Algorithm: - 1. Read named.conf.local content. - 2. Use regex to find forward zone blocks and their forwarders. - 3. Return a list of DNSForwardZone objects. - - Returns: - list[DNSForwardZone]: List of forward zones. - - """ - named_local = None - with open(NAMED_LOCAL) as file: - named_local = file.read() - - pattern = r""" - zone\s+"([^"]+)"\s*{[^}]*? 
- type\s+forward\b[^}]*? - forwarders\s*{([^}]+)} - """ - - matches = re.findall(pattern, named_local, re.DOTALL | re.VERBOSE) - - result = [] - for zone_name, forwarders in matches: - clean_forwarders = [ - forwarder.strip() - for forwarder in forwarders.split(";") - if forwarder.strip() - ] - result.append( - DNSForwardZone( - zone_name, - DNSZoneType.FORWARD, - clean_forwarders, - ), - ) - - return result - - def add_record( - self, - record: DNSRecord, - record_type: DNSRecordType, - zone_name: str, - ) -> None: - """Add a DNS record to a zone. - - Algorithm: - 1. Load the zone object. - 2. Build rdata by type and value. - 3. Add rdata to the rdataset. - 4. Save changes to the zone file and reload the zone. - - Args: - record (DNSRecord): DNS record to add. - record_type (DNSRecordType): Type of the DNS record. - zone_name (str): Name of the DNS zone. - - """ - zone = self._get_zone_obj_by_zone_name(zone_name) - - record_name = dns.name.from_text(record.name) - rdata = dns.rdata.from_text( - dns.rdataclass.IN, - dns.rdatatype.from_text(record_type), - record.value, - ) - - zone.find_rdataset(record_name, rdata.rdtype, create=True).add( - rdata, - ttl=record.ttl, - ) - - self._write_zone_data_to_file(zone_name, zone) - - def delete_record( - self, - record: DNSRecord, - record_type: DNSRecordType, - zone_name: str, - ) -> None: - """Delete a record from a zone. - - Algorithm: - 1. Load the zone object. - 2. Find the rdataset by name and type. - 3. If rdata is present, remove it from the rdataset. - 4. Save changes to the zone file and reload the zone. - - Args: - record (DNSRecord): DNS record to delete. - record_type (DNSRecordType): Type of the DNS record. - zone_name (str): Name of the DNS zone. 
- - """ - zone = self._get_zone_obj_by_zone_name(zone_name) - name = dns.name.from_text(record.name) - rdatatype = dns.rdatatype.from_text(record_type) - rdata = dns.rdata.from_text( - dns.rdataclass.IN, - rdatatype, - record.value, - ) - - if name in zone.nodes: - node = zone.nodes[name] - rdataset = node.get_rdataset(dns.rdataclass.IN, rdatatype) - if rdataset and rdata in rdataset: - rdataset.remove(rdata) - - self._write_zone_data_to_file(zone_name, zone) - - def update_record( - self, - old_record: DNSRecord, - new_record: DNSRecord, - record_type, - zone_name, - ) -> None: - """Update a record in a zone (value or TTL). - - Algorithm: - 1. Delete the old record. - 2. Add the new record with updated values. - - Args: - old_record (DNSRecord): Old DNS record. - new_record (DNSRecord): New DNS record. - record_type: Type of the DNS record. - zone_name (str): Name of the DNS zone. - - """ - self.delete_record(old_record, record_type, zone_name) - self.add_record(new_record, record_type, zone_name) - - @staticmethod - def _add_new_server_param( - named_options: str, - param_name: str, - param_value: str, - ) -> str: - """Add a new parameter to the options block in named.conf.options. - - Regex explanation: - - (options\\s*\\{{[\\s\\S]*?) - Captures the start of the options block and all its content - up to the closing '};'. - - (\\s*\\}};) - Captures the closing of the options block - (with optional whitespace). - The regex is used to insert a new parameter just before the end of - the options block. - - Algorithm: - 1. Use re.sub to add the parameter line inside the options block. - 2. Return the modified text. - - Args: - named_options (str): Contents of named.conf.options. - param_name (str): Parameter name. - param_value (str): Parameter value. - - Returns: - str: Modified named.conf.options content. 
- - """ - return re.sub( - r"(options\s*\{[\s\S]*?)(\s*\};)", - rf"\1 {param_name} {param_value};\2", - named_options, - flags=re.DOTALL, - ) - - def update_dns_settings(self, settings: list[DNSServerParam]) -> None: - """Update or add DNS server parameters. - - Regex explanation: - - \\b{param_name}\\s+ - Matches the parameter name as a whole word, - followed by whitespace. - - ([^;\\n{{]+|{{[^}}]+}}) - Captures the parameter value, which can be a simple value or - a block in braces. - The first capturing group contains the parameter value. - - Algorithm: - 1. Read named.conf.options content. - 2. For each parameter, search for it using regex. - 3. If not found, add it; otherwise, update it. - 4. Write the modified config back to the file. - - Args: - settings (list[DNSServerParam]): List of server parameters. - - """ - named_options = None - - with open(NAMED_OPTIONS) as file: - named_options = file.read() - - for param in settings: - if isinstance(param.value, list): - param_value = "{ " + f"{'; '.join(param.value)};" + " }" - else: - param_value = param.value - pattern = rf"\b{re.escape(param.name)}\s+([^;\n{{]+|{{[^}}]+}})" - matched_param = re.search( - pattern, - named_options, - flags=re.MULTILINE, - ) - if matched_param is None: - named_options = self._add_new_server_param( - named_options, - param.name, - param_value, - ) - else: - named_options = re.sub( - pattern, - f"{param.name} {param_value}", - named_options, - ) - - error = self._check_config(named_options) - if error: - raise DNSZoneConfigError( - f"Error while updating DNS settings: {error}", - ) - - with open(NAMED_OPTIONS, "w") as file: - file.write(named_options) - - self.restart() - - @staticmethod - def get_server_settings() -> list[DNSServerParam]: - """Get a list of modifiable DNS server settings. - - Regex explanation: - - \\b{param_name}\\s+ - Matches the parameter name as a whole word, - followed by whitespace. 
- - ([^;\\n{{]+|{{[^}}]+}}) - Captures the parameter value, which can be a simple value or - a block in braces. - The first capturing group contains the parameter value. - - Algorithm: - 1. Read named.conf.options content. - 2. For each parameter in DNSServerParamName, - search for its value using regex. - 3. Return a list of DNSServerParam objects. - - Returns: - list[DNSServerParam]: List of server parameters. - - """ - named_options = None - with open(NAMED_OPTIONS) as file: - named_options = file.read() - - result = [] - for param_name in DNSServerParamName: - pattern = rf"\b{re.escape(param_name)}\s+([^;\n{{]+|{{[^}}]+}})" - matched_param_value = re.search(pattern, named_options) - if not matched_param_value: - continue - result.append( - DNSServerParam( - name=param_name, - value=matched_param_value.group(1).strip(), - ), - ) - - return result - - -async def get_dns_manager() -> type[BindDNSServerManager]: - """Get DNS server manager client.""" - return BindDNSServerManager() - - -zone_router = APIRouter(prefix="/zone", tags=["zone"]) -record_router = APIRouter(prefix="/record", tags=["record"]) -server_router = APIRouter(prefix="/server", tags=["server"]) - - -@zone_router.post("") -def create_zone( - data: DNSZoneCreateRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Create DNS zone.""" - dns_manager.add_zone( - data.zone_name, - data.zone_type, - data.nameserver, - data.params, - ) - - -@zone_router.patch("") -def update_zone( - data: DNSZoneUpdateRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Update DNS zone settings.""" - dns_manager.update_zone(data.zone_name, data.params) - - -@zone_router.delete("") -def delete_zone( - data: DNSZoneDeleteRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Delete DNS zone.""" - dns_manager.delete_zone(data.zone_name) - - -@zone_router.get("") -async def 
get_all_records_by_zone( - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> list[DNSZone]: - """Get all DNS records grouped by zone.""" - return dns_manager.get_all_records() - - -@zone_router.get("/forward") -async def get_forward_zones( - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> list[DNSForwardZone]: - """Get all forward DNS zones.""" - return await dns_manager.get_forward_zones() - - -@record_router.post("") -def create_record( - data: DNSRecordCreateRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Create DNS record in given zone.""" - dns_manager.add_record( - DNSRecord( - data.record_name, - data.record_value, - data.ttl, - ), - data.record_type, - data.zone_name, - ) - - -@record_router.patch("") -def update_record( - data: DNSRecordUpdateRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Update existing DNS record.""" - dns_manager.update_record( - old_record=DNSRecord( - data.record_name, - data.record_value, - 0, - ), - new_record=DNSRecord( - data.record_name, - data.record_value, - data.ttl, - ), - record_type=data.record_type, - zone_name=data.zone_name, - ) - - -@record_router.delete("") -def delete_record( - data: DNSRecordDeleteRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Delete existing DNS record.""" - dns_manager.delete_record( - DNSRecord( - data.record_name, - data.record_value, - 0, - ), - data.record_type, - data.zone_name, - ) - - -@server_router.get("/restart") -def restart_dns_server( - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Restart DNS server via reconfig.""" - dns_manager.restart() - - -@zone_router.get("/reload/{zone_name}") -def reload_zone( - zone_name: str, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Force reload DNS zone 
from zone file.""" - dns_manager.reload(zone_name) - - -@server_router.patch("/settings") -def update_dns_server_settings( - settings: list[DNSServerParam], - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Update settings of DNS server.""" - dns_manager.update_dns_settings(settings) - - -@server_router.get("/settings") -async def get_server_settings( - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> list[DNSServerParam]: - """Get list of modifiable server settings.""" - return dns_manager.get_server_settings() - - -@server_router.post("/setup") -def setup_server( - data: DNSServerSetupRequest, - dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], -) -> None: - """Init setup of DNS server.""" - dns_manager.first_setup(data.zone_name) - - -async def handle_dns_error( - request: Request, # noqa: ARG001 - exc: Exception, -) -> NoReturn: - """Handle DNS API error.""" - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) - - -def create_app() -> FastAPI: - """Create FastAPI app.""" - app = FastAPI( - name="DNSServerManager", - title="DNSServerManager", - ) - - app.include_router(record_router) - app.include_router(zone_router) - app.include_router(server_router) - - app.add_exception_handler(DNSError, handler=handle_dns_error) - - return app diff --git a/.dns/entrypoint.sh b/.dns/entrypoint.sh deleted file mode 100755 index 25e0891d9..000000000 --- a/.dns/entrypoint.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -fix_rndc_key() { - local rndc_key="/etc/bind/rndc.key" - if [ -f "$rndc_key" ]; then - chown bind:bind "$rndc_key" 2>/dev/null || chown 100:101 "$rndc_key" 2>/dev/null || true - chmod 640 "$rndc_key" 2>/dev/null || true - fi -} - -/usr/local/bin/docker-entrypoint.sh & - -fix_rndc_key - -/venvs/.venv/bin/python3.13 -m uvicorn --factory dns_api:create_app --host 0.0.0.0 --reload & - -wait -n - -exit $? 
diff --git a/.dns/templates/zone.template b/.dns/templates/zone.template deleted file mode 100644 index 249ebc349..000000000 --- a/.dns/templates/zone.template +++ /dev/null @@ -1,11 +0,0 @@ -$ORIGIN . -$TTL {{ ttl }} -{{ domain }} IN SOA ns1.{{ nameserver }}. support.md.ru. ( - {{ today }}01 - 10800 - 3600 - 604800 - 21600 - ) - IN NS ns1.{{ nameserver }}. - IN NS ns2.{{ nameserver }}. diff --git a/.dns/templates/zone_options.template b/.dns/templates/zone_options.template deleted file mode 100644 index 22f20a3f3..000000000 --- a/.dns/templates/zone_options.template +++ /dev/null @@ -1,10 +0,0 @@ -zone "{{ zone_name }}" { - type {{ zone_type }}; - {%- if zone_type == "master" %} - file "/opt/{{ zone_name }}.zone"; - notify no; - {%- endif %} - {%- if zone_type == "forward" %} - forward only; - {%- endif %} -}; diff --git a/.docker/Dockerfile b/.docker/Dockerfile index b7942c3c8..253269161 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -35,7 +35,7 @@ ENV VIRTUAL_ENV=/venvs/.venv \ VERSION=${VERSION:-beta} -RUN set -eux; apk add --no-cache krb5-libs curl openssl netcat-openbsd +RUN set -eux; apk add --no-cache krb5-libs curl openssl netcat-openbsd libsodium-dev COPY app /app COPY pyproject.toml / diff --git a/.docker/bind9.Dockerfile b/.docker/bind9.Dockerfile deleted file mode 100644 index d5b8154a8..000000000 --- a/.docker/bind9.Dockerfile +++ /dev/null @@ -1,45 +0,0 @@ -FROM python:3.13-bookworm AS builder - -ENV VIRTUAL_ENV=/venvs/.venv \ - PATH="/venvs/.venv/bin:$PATH" - -WORKDIR /venvs - -RUN python -m venv .venv -RUN pip install \ - fastapi==0.115.12 \ - uvicorn==0.34.2 \ - pydantic==2.10.6 \ - jinja2==3.1.6 \ - dnspython==2.7.0 - -FROM ubuntu/bind9:latest AS runtime - -ENV LANG=C.UTF-8 \ - DEBIAN_FRONTEND=noninteractive \ - VIRTUAL_ENV=/venvs/.venv \ - PATH="/venvs/.venv/bin:$PATH" \ - PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 - -RUN apt update -RUN apt install -y python3.13 - -COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} - -RUN ln 
-sf /usr/bin/python3.13 /venvs/.venv/bin/python - -COPY .dns/ /server/ -WORKDIR /server - -RUN chown bind:bind /opt - -RUN mkdir /var/log/named && \ - touch /var/log/named/bind.log && \ - chown bind:bind /var/log/named && \ - chmod 755 /var/log/named && \ - chmod 644 /var/log/named/bind.log - -EXPOSE 8000 - -ENTRYPOINT [ "./entrypoint.sh" ] diff --git a/.docker/dev.Dockerfile b/.docker/dev.Dockerfile index 0e89ebe96..ca3ac7295 100644 --- a/.docker/dev.Dockerfile +++ b/.docker/dev.Dockerfile @@ -33,7 +33,7 @@ ENV VIRTUAL_ENV=/venvs/.venv \ PATH="/venvs/.venv/bin:$PATH" \ VERSION=${VERSION:-beta} -RUN set -eux; apk add --no-cache krb5-libs curl openssl netcat-openbsd +RUN set -eux; apk add --no-cache krb5-libs curl openssl netcat-openbsd libsodium-dev COPY app /app COPY pyproject.toml / diff --git a/.docker/krb.Dockerfile b/.docker/krb.Dockerfile index afbee2892..5533b1b2a 100644 --- a/.docker/krb.Dockerfile +++ b/.docker/krb.Dockerfile @@ -7,12 +7,13 @@ ENV VIRTUAL_ENV=/venvs/.venv \ PATH="/venvs/.venv/bin:$PATH" WORKDIR /venvs +COPY .kerberos/kadmin_local-0.1.1.tar.gz / RUN python -m venv .venv RUN pip install \ fastapi \ uvicorn \ - https://github.com/xianglei/python-kadmv/releases/download/0.1.7/python-kadmV-0.1.7.tar.gz + /kadmin_local-0.1.1.tar.gz FROM ghcr.io/multidirectorylab/krb5_base:${VERSION} AS runtime diff --git a/.docker/pdns_auth.Dockerfile b/.docker/pdns_auth.Dockerfile new file mode 100644 index 000000000..6298932ff --- /dev/null +++ b/.docker/pdns_auth.Dockerfile @@ -0,0 +1,66 @@ +FROM alpine:3.20 AS builder + +RUN apk add --no-cache --virtual .build-deps \ + build-base \ + lmdb-dev \ + openssl-dev \ + boost-dev \ + autoconf automake libtool \ + git ragel bison flex \ + lua5.4-dev \ + curl-dev + +RUN apk add --no-cache \ + lua \ + lua-dev \ + lmdb \ + boost-libs \ + openssl-libs-static \ + curl \ + libstdc++ + +RUN git clone https://github.com/PowerDNS/pdns.git /pdns +WORKDIR /pdns + +RUN git submodule init &&\ + git submodule update &&\ + git 
checkout auth-5.0.1 + +RUN autoreconf -vi + +RUN mkdir /build && \ + ./configure \ + --sysconfdir=/etc/powerdns \ + --enable-option-checking=fatal \ + --with-dynmodules='lmdb' \ + --with-modules='' \ + --with-unixodbc-lib=/usr/lib/$(dpkg-architecture -q DEB_BUILD_GNU_TYPE) && \ + make clean && \ + make $MAKEFLAGS -C ext &&\ + make $MAKEFLAGS -C modules &&\ + make $MAKEFLAGS -C pdns && \ + make -C pdns install DESTDIR=/build &&\ + make -C modules install DESTDIR=/build &&\ + make clean && \ + strip /build/usr/local/bin/* /build/usr/local/sbin/* /build/usr/local/lib/pdns/*.so + +FROM alpine:3.20 AS runtime + +COPY --from=builder /build / + +RUN apk add --no-cache \ + lua \ + lua-dev \ + lmdb \ + boost-libs \ + openssl-libs-static \ + curl \ + libstdc++ + +RUN mkdir -p /etc/powerdns/pdns.d /var/run/pdns /var/lib/powerdns /etc/powerdns/templates.d /var/lib/pdns-lmdb + +COPY ./.package/pdns.conf /etc/powerdns/pdns.conf + +EXPOSE 8082/tcp + +CMD ["/usr/local/sbin/pdns_server"] \ No newline at end of file diff --git a/.docker/test.Dockerfile b/.docker/test.Dockerfile index da4c9461e..288e59eea 100644 --- a/.docker/test.Dockerfile +++ b/.docker/test.Dockerfile @@ -27,7 +27,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ FROM python:3.13.7-alpine3.21 AS runtime WORKDIR /app -RUN set -eux; apk add --no-cache openldap-clients openssl curl krb5-libs +RUN set -eux; apk add --no-cache openldap-clients openssl curl krb5-libs libsodium-dev ENV VIRTUAL_ENV=/venvs/.venv \ PATH="/venvs/.venv/bin:$PATH" \ diff --git a/.github/workflows/build-beta.yml b/.github/workflows/build-beta.yml index 8e2f0b4a9..de070e1c7 100644 --- a/.github/workflows/build-beta.yml +++ b/.github/workflows/build-beta.yml @@ -156,7 +156,7 @@ jobs: --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg VERSION=beta - build-bind9: + build-pdns_auth: runs-on: ubuntu-latest needs: [build-tests, run-ssh-test, run-tests] steps: @@ -173,14 +173,14 @@ jobs: - name: Build docker image env: - TAG: ghcr.io/${{ env.REPO 
}}_bind9:beta + TAG: ghcr.io/${{ env.REPO }}_pdns_auth:beta DOCKER_BUILDKIT: '1' run: | echo $TAG docker build \ --push \ --target=runtime \ - -f .docker/bind9.Dockerfile . \ + -f .docker/pdns_auth.Dockerfile . \ -t $TAG \ --cache-to type=gha,mode=max \ --cache-from $TAG \ diff --git a/.github/workflows/build-dev.yml b/.github/workflows/build-dev.yml index 6bd6a2ce6..a02d3d70a 100644 --- a/.github/workflows/build-dev.yml +++ b/.github/workflows/build-dev.yml @@ -155,7 +155,7 @@ jobs: --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg VERSION=dev - build-bind9: + build-pdns_auth: runs-on: ubuntu-latest needs: [build-tests, run-ssh-test, run-tests] steps: @@ -172,14 +172,14 @@ jobs: - name: Build docker image env: - TAG: ghcr.io/${{ env.REPO }}_bind9:dev + TAG: ghcr.io/${{ env.REPO }}_pdns_auth:dev DOCKER_BUILDKIT: '1' run: | echo $TAG docker build \ --push \ --target=runtime \ - -f .docker/bind9.Dockerfile . \ + -f .docker/pdns_auth.Dockerfile . \ -t $TAG \ --cache-to type=gha,mode=max \ --cache-from $TAG \ diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index a607b3c7a..367f4e0bc 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -176,7 +176,7 @@ jobs: --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg VERSION=latest - build-bind9: + build-pdns_auth: runs-on: ubuntu-latest needs: [build-tests, run-ssh-test, run-tests] steps: @@ -193,14 +193,14 @@ jobs: - name: Build docker image env: - TAG: ghcr.io/${{ env.REPO }}_bind9:latest + TAG: ghcr.io/${{ env.REPO }}_pdns_auth:latest DOCKER_BUILDKIT: '1' run: | echo $TAG docker build \ --push \ --target=runtime \ - -f .docker/bind9.Dockerfile . \ + -f .docker/pdns_auth.Dockerfile . 
\ -t $TAG \ --cache-to type=gha,mode=max \ --cache-from $TAG \ diff --git a/.gitignore b/.gitignore index 9c09001ca..11596e8f9 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +resolve.conf # PyInstaller # Usually these files are written by a python script from a template diff --git a/.gitmodules b/.gitmodules index 7ea94c39a..e69de29bb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +0,0 @@ -[submodule "interface"] - path = interface - url = https://github.com/MultifactorLab/MultiDirectory-Web-Admin.git - ignore = all diff --git a/.kerberos/config_server.py b/.kerberos/config_server.py index 2806c9b86..d3cc24dfb 100644 --- a/.kerberos/config_server.py +++ b/.kerberos/config_server.py @@ -31,7 +31,7 @@ ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse -from pydantic import BaseModel +from pydantic import BaseModel, Field from starlette.background import BackgroundTask KRB5_CONF_PATH = "/etc/krb5.conf" @@ -79,6 +79,30 @@ class PrincipalNotFoundError(Exception): """Not found error.""" +class AddPrincipalRequest(BaseModel): + """Request model for adding principal.""" + + principal_name: str + password: str | None = None + algorithms: list[str] | None = None + + +class KtaddRequest(BaseModel): + """Request model for ktadd.""" + + names: list[str] + is_rand_key: bool = Field(default=False) + + +class ModifyPrincipalRequest(BaseModel): + """Request model for modifying principal.""" + + principal_name: str + new_name: str | None = None + algorithms: list[str] | None = None + password: str | None = None + + class AbstractKRBManager(ABC): """Kadmin manager.""" @@ -95,12 +119,14 @@ async def add_princ( self, name: str, password: str | None, + algorithms: list[str] | None = None, **dbargs, ) -> None: """Create principal. :param str name: principal - :param str | None password: if empty - uses randkey. + :param str | None password: if None - uses randkey. 
+ :param list[str] | None algorithms: encryption algorithms """ @abstractmethod @@ -135,19 +161,17 @@ async def del_princ(self, name: str) -> None: """ @abstractmethod - async def rename_princ(self, name: str, new_name: str) -> None: - """Rename principal. - - :param str name: original name - :param str new_name: new name - """ - - @abstractmethod - async def ktadd(self, names: list[str], fn: str) -> None: + async def ktadd( + self, + names: list[str], + fn: str, + is_rand_key: bool = False, + ) -> None: """Create or write to keytab. - :param str name: principal + :param list[str] names: principals :param str fn: filename + :param bool is_rand_key: generate new principal keys """ @abstractmethod @@ -164,6 +188,23 @@ async def force_pw_principal(self, name: str, **dbargs) -> None: :param str name: principal """ + @abstractmethod + async def modify_principal( + self, + principal_name: str, + new_name: str | None = None, + algorithms: list[str] | None = None, + password: str | None = None, + **dbargs, + ) -> None: + """Modify principal (rename, change algorithms, password). + + :param str principal_name: current principal name + :param str | None new_name: new name if rename needed + :param list[str] | None algorithms: new encryption algorithms + :param str | None password: new password + """ + class KAdminLocalManager(AbstractKRBManager): """Kadmin manager.""" @@ -206,18 +247,21 @@ async def add_princ( self, name: str, password: str | None, + algorithms: list[str] | None = None, **dbargs, ) -> None: """Create principal. :param str name: principal - :param str | None password: if empty - uses randkey. + :param str | None password: if None - uses randkey. 
+ :param list[str] | None algorithms: encryption algorithms """ await self.loop.run_in_executor( self.pool, self.client.add_principal, name, password, + algorithms, ) if password: @@ -287,32 +331,30 @@ async def del_princ(self, name: str) -> None: except kadmv.UnknownPrincipalError: raise PrincipalNotFoundError - async def rename_princ(self, name: str, new_name: str) -> None: - """Rename principal. - - :param str name: original name - :param str new_name: new name - """ - await self.loop.run_in_executor( - self.pool, - self.client.rename_principal, - name, - new_name, - ) - - async def ktadd(self, names: list[str], fn: str) -> None: + async def ktadd( + self, + names: list[str], + fn: str, + is_rand_key: bool = True, + ) -> None: """Create or write to keytab. - :param str name: principal + :param list[str] names: principals :param str fn: filename - :raises self.PrincipalNotFoundError: on not found princ + :param bool is_rand_key: generate new principal keys + :raises PrincipalNotFoundError: on not found princ """ principals = [await self._get_raw_principal(name) for name in names] if not all(principals): raise PrincipalNotFoundError("Principal not found") for princ in principals: - await self.loop.run_in_executor(self.pool, princ.ktadd, fn) + await self.loop.run_in_executor( + self.pool, + princ.ktadd, + fn, + is_rand_key, + ) async def lock_princ(self, name: str, **dbargs) -> None: """Lock princ. @@ -332,6 +374,36 @@ async def force_pw_principal(self, name: str, **dbargs) -> None: princ.pwexpire = "Now" await self.loop.run_in_executor(self.pool, princ.commit) + async def modify_principal( + self, + principal_name: str, + new_name: str | None = None, + algorithms: list[str] | None = None, + password: str | None = None, + **dbargs, + ) -> None: + """Modify principal (rename, change algorithms, password). 
+ + :param str principal_name: current principal name + :param str | None new_name: new name if rename needed + :param list[str] | None algorithms: new encryption algorithms + :param str | None password: new password + """ + args = [] + if new_name: + args.append(new_name) + if password: + args.append(password) + if algorithms: + args.append(algorithms) + + await self.loop.run_in_executor( + self.pool, + self.client.modify_principal, + principal_name, + *args, + ) + @asynccontextmanager async def kadmin_lifespan(app: FastAPI) -> AsyncIterator[None]: @@ -494,16 +566,18 @@ async def reset_setup() -> None: @principal_router.post("", response_class=Response, status_code=201) async def add_princ( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], - name: Annotated[str, Body()], - password: Annotated[str | None, Body(embed=True)] = None, + request: AddPrincipalRequest, ) -> None: """Add principal. :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + :param AddPrincipalRequest request: request data """ - await kadmin.add_princ(name, password) + await kadmin.add_princ( + request.principal_name, + request.password, + algorithms=request.algorithms, + ) @principal_router.get("") @@ -511,11 +585,10 @@ async def get_princ( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], name: str, ) -> Principal: - """Add principal. + """Get principal. :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + :param str name: principal name """ return await kadmin.get_princ(name) @@ -525,11 +598,10 @@ async def del_princ( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], name: str, ) -> None: - """Add principal. + """Delete principal. 
:param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + :param str name: principal name """ await kadmin.del_princ(name) @@ -569,38 +641,43 @@ async def create_or_update_princ_password( @principal_router.put( - "", + "/modify", status_code=status.HTTP_202_ACCEPTED, response_class=Response, ) -async def rename_princ( +async def modify_princ( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], - name: Annotated[str, Body()], - new_name: Annotated[str, Body()], + request: ModifyPrincipalRequest, ) -> None: - """Rename principal. + """Modify principal (rename, algorithms, password). :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body new_name: principal new name + :param ModifyPrincipalRequest request: request data """ - """""" - await kadmin.rename_princ(name, new_name) + await kadmin.modify_principal( + principal_name=request.principal_name, + new_name=request.new_name, + algorithms=request.algorithms, + password=request.password, + ) @principal_router.post("/ktadd") async def ktadd( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], - names: Annotated[list[str], Body()], + request: KtaddRequest, ) -> FileResponse: """Ktadd principal. 
:param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + :param KtaddRequest request: request data """ filename = os.path.join(gettempdir(), str(uuid.uuid1())) - await kadmin.ktadd(names, filename) + await kadmin.ktadd( + request.names, + filename, + request.is_rand_key, + ) return FileResponse( filename, diff --git a/.kerberos/kadmin_local-0.1.1.tar.gz b/.kerberos/kadmin_local-0.1.1.tar.gz new file mode 100644 index 000000000..b4cba9055 Binary files /dev/null and b/.kerberos/kadmin_local-0.1.1.tar.gz differ diff --git a/.package/dnsdist.conf b/.package/dnsdist.conf new file mode 100644 index 000000000..c74ab876e --- /dev/null +++ b/.package/dnsdist.conf @@ -0,0 +1,6 @@ +setLocal('0.0.0.0:53') +controlSocket('0.0.0.0:8084') +setKey('supersecretapikey') +addConsoleACL('172.20.0.0/24') +includeDirectory('/etc/dnsdist/conf.d/') +setACL('0.0.0.0/0') diff --git a/.package/docker-compose.yml b/.package/docker-compose.yml index 104a416bd..dc7db924b 100644 --- a/.package/docker-compose.yml +++ b/.package/docker-compose.yml @@ -5,12 +5,12 @@ services: traefik: image: "mirror.gcr.io/traefik:v3.6.1" container_name: traefik + networks: + md_net: restart: unless-stopped command: - "--providers.file.filename=/traefik.yml" ports: - - "53:53" - - "53:53/udp" - "80:80" - "389:389" - "389:389/udp" @@ -42,6 +42,8 @@ services: traefik_certs_dumper: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: traefik_certs_dumper + networks: + md_net: restart: "on-failure" env_file: .env @@ -56,6 +58,8 @@ services: interface: image: ghcr.io/multidirectorylab/multidirectory-web-admin:${VERSION:-latest} container_name: multidirectory_interface + networks: + md_net: restart: unless-stopped hostname: interface environment: @@ -74,6 +78,8 @@ services: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: 
multidirectory_migrations restart: "no" + networks: + md_net: env_file: .env command: python multidirectory.py --migrate @@ -81,8 +87,32 @@ services: postgres: condition: service_healthy + dns_migration: + image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} + container_name: multidirectory_dns_migration + networks: + md_net: + restart: "no" + volumes: + - dns_server_file:/opt/ + - dns_server_config:/etc/bind/ + - dnsdist_confd:/dnsdist + env_file: .env + command: python multidirectory.py --migrate_dns + depends_on: + migrations: + condition: service_completed_successfully + pdns_auth: + condition: service_started + pdns_recursor: + condition: service_started + pdnsdist: + condition: service_started + ldap_server: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} + networks: + md_net: restart: unless-stopped hostname: multidirectory-ldap env_file: @@ -119,7 +149,7 @@ services: - traefik.tcp.routers.ldap.entrypoints=ldap - traefik.tcp.routers.ldap.service=ldap - traefik.tcp.services.ldap.loadbalancer.server.port=389 - - traefik.tcp.services.ldap.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.ldap.loadbalancer.serversTransport=ldap_transport@file - traefik.tcp.routers.ldaps.rule=HostSNI(`*`) - traefik.tcp.routers.ldaps.entrypoints=ldaps @@ -127,10 +157,12 @@ services: - traefik.tcp.routers.ldaps.tls=true - traefik.tcp.routers.ldaps.tls.certResolver=md-resolver - traefik.tcp.services.ldaps.loadbalancer.server.port=636 - - traefik.tcp.services.ldaps.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.ldaps.loadbalancer.serversTransport=ldap_transport@file cldap_server: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} + networks: + md_net: restart: unless-stopped environment: - SERVICE_NAME=cldap_server @@ -166,6 +198,8 @@ services: cpus: "0.25" memory: 100M image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} + networks: + md_net: restart: unless-stopped environment: - 
SERVICE_NAME=global_ldap_server @@ -193,7 +227,7 @@ services: - traefik.tcp.routers.global_ldap.entrypoints=global_ldap - traefik.tcp.routers.global_ldap.service=global_ldap - traefik.tcp.services.global_ldap.loadbalancer.server.port=3268 - - traefik.tcp.services.global_ldap.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.global_ldap.loadbalancer.serversTransport=ldap_transport@file - traefik.tcp.routers.global_ldap_tls.rule=HostSNI(`*`) - traefik.tcp.routers.global_ldap_tls.entrypoints=global_ldap_tls @@ -201,11 +235,13 @@ services: - traefik.tcp.routers.global_ldap_tls.tls=true - traefik.tcp.routers.global_ldap_tls.tls.certresolver=md-resolver - traefik.tcp.services.global_ldap_tls.loadbalancer.server.port=3269 - - traefik.tcp.services.global_ldap_tls.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.global_ldap_tls.loadbalancer.serversTransport=ldap_transport@file api_server: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: multidirectory_api + networks: + md_net: restart: unless-stopped env_file: .env @@ -216,6 +252,7 @@ services: - dns_server_config:/DNS_server_configs/ - ldap_keytab:/LDAP_keytab/ - ./resolv.conf:/resolv.conf + - dnsdist_confd:/dnsdist hostname: api_server environment: USE_CORE_TLS: 1 @@ -230,7 +267,6 @@ services: - "traefik.http.routers.api.service=api" - "traefik.http.routers.api.middlewares=api_strip" - "traefik.http.middlewares.api_strip.stripprefix.prefixes=/api" - - "traefik.http.middlewares.api_strip.stripprefix.forceslash=false" command: python multidirectory.py --http depends_on: @@ -239,6 +275,8 @@ services: postgres: container_name: MD-postgres + networks: + md_net: image: mirror.gcr.io/postgres:16 restart: unless-stopped env_file: @@ -260,6 +298,8 @@ services: cert_check: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: multidirectory_certs_check + networks: + md_net: restart: "no" volumes: - ./certs:/certs @@ -268,6 +308,8 @@ services: 
maintence: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: md_maintence + networks: + md_net: restart: unless-stopped volumes: - ./certs:/certs @@ -292,6 +334,8 @@ services: kdc: container_name: kdc + networks: + md_net: restart: unless-stopped hostname: kerberos volumes: @@ -308,6 +352,8 @@ services: kadmin_api: image: ghcr.io/multidirectorylab/multidirectory_kerberos:${VERSION:-latest} container_name: kadmin_api + networks: + md_net: restart: unless-stopped volumes: - ./certs:/certs @@ -322,9 +368,12 @@ services: condition: service_started working_dir: /server command: ./entrypoint.sh + kadmind: image: ghcr.io/multidirectorylab/multidirectory_kerberos:${VERSION:-latest} container_name: kadmind + networks: + md_net: restart: unless-stopped hostname: kerberos volumes: @@ -347,27 +396,58 @@ services: - traefik.tcp.routers.kpasswd.service=kpasswd - traefik.tcp.services.kpasswd.loadbalancer.server.port=464 - bind_dns: - image: ghcr.io/multidirectorylab/multidirectory_bind9:${VERSION:-latest} - container_name: bind9 - hostname: bind9 - restart: unless-stopped + pdns_auth: + image: ghcr.io/multidirectorylab/multidirectory_pdns_auth:${VERSION:-latest} + container_name: pdns_auth + cap_add: + - NET_ADMIN + networks: + default: + md_net: + ipv4_address: 172.20.0.202 + expose: + - 8082 + - 53/udp + - 53/tcp volumes: - - dns_server_file:/opt/ - - dns_server_config:/etc/bind/ - tty: true - env_file: - - .env - environment: - - USE_CONFIG_FILE_LOGGING=true - depends_on: - ldap_server: - condition: service_healthy - restart: true - labels: - - traefik.enable=true - - traefik.udp.routers.bind_dns_udp.entrypoints=bind_dns_udp - - traefik.udp.services.bind_dns_udp.loadbalancer.server.port=53 + - dns_lmdb:/var/lib/pdns-lmdb + - ./pdns.conf:/etc/powerdns/pdns.conf + + + pdns_recursor: + image: powerdns/pdns-recursor-51:5.1.7 + container_name: pdns_recursor + cap_add: + - NET_ADMIN + networks: + default: + md_net: + ipv4_address: 172.20.0.200 + 
expose: + - 8083 + - 53/udp + - 53/tcp + volumes: + - ./recursor.conf:/etc/powerdns/recursor.conf + - forward_zones:/etc/powerdns/recursor.d/ + + pdnsdist: + image: powerdns/dnsdist-19:1.9.11 + container_name: pdnsdist + cap_add: + - NET_ADMIN + networks: + default: + md_net: + ipv4_address: 172.20.0.201 + expose: + - 8084 + ports: + - "53:53/tcp" + - "53:53/udp" + volumes: + - ./dnsdist.conf:/etc/dnsdist/dnsdist.conf + - dnsdist_confd:/etc/dnsdist/conf.d kea_dhcp4: image: ghcr.io/multidirectorylab/multidirectory_dhcp4:${VERSION:-latest} @@ -389,6 +469,8 @@ services: kea_ctrl_agent: image: jonasal/kea-ctrl-agent:3.1.2-alpine container_name: kea_ctrl_agent + networks: + md_net: restart: unless-stopped command: -c /kea/config/kea-ctrl-agent.conf tty: true @@ -403,6 +485,8 @@ services: dragonfly_mem: image: 'docker.dragonflydb.io/dragonflydb/dragonfly' container_name: dragonfly + networks: + md_net: restart: unless-stopped volumes: - dragonflydata:/data @@ -424,6 +508,8 @@ services: shadow_api: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: shadow_api + networks: + md_net: restart: unless-stopped tty: true volumes: @@ -440,6 +526,8 @@ services: event_handler: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: event_handler + networks: + md_net: restart: unless-stopped tty: true env_file: @@ -454,6 +542,8 @@ services: event_sender: image: ghcr.io/multidirectorylab/multidirectory:${VERSION:-latest} container_name: event_sender + networks: + md_net: restart: unless-stopped tty: true depends_on: @@ -467,6 +557,14 @@ services: environment: HANDLER_NAME: event_sender-1 +networks: + md_net: + driver: bridge + ipam: + config: + - subnet: 172.20.0.0/24 + gateway: 172.20.0.1 + volumes: postgres: kdc: @@ -478,3 +576,7 @@ volumes: leases: sockets: dhcp: + dns_lmdb: + dns_config: + forward_zones: + dnsdist_confd: diff --git a/.package/pdns.conf b/.package/pdns.conf new file mode 100644 index 000000000..80635dc2b 
--- /dev/null +++ b/.package/pdns.conf @@ -0,0 +1,11 @@ +launch=lmdb +lmdb-filename=/var/lib/pdns-lmdb/pdns.lmdb +daemon=no +local-address=0.0.0.0 +local-port=53 +api=yes +api-key=supersecretapikey +webserver-allow-from=0.0.0.0/0 +webserver=yes +webserver-address=0.0.0.0 +webserver-port=8082 diff --git a/.package/recursor.conf b/.package/recursor.conf new file mode 100644 index 000000000..a47cacc50 --- /dev/null +++ b/.package/recursor.conf @@ -0,0 +1,10 @@ +local-address=0.0.0.0 +webserver-allow-from=0.0.0.0/0 +forward-zones-recurse=.=1.1.1.1;8.8.8.8 +forward-zones= +api-config-dir=/etc/powerdns/recursor.d/ +include-dir=/etc/powerdns/recursor.d/ +webserver=yes +webserver-address=0.0.0.0 +webserver-port=8083 +api-key=supersecretapikey diff --git a/.package/resolv.conf b/.package/resolv.conf new file mode 100644 index 000000000..a151b09ed --- /dev/null +++ b/.package/resolv.conf @@ -0,0 +1,16 @@ +# +# macOS Notice +# +# This file is not consulted for DNS hostname resolution, address +# resolution, or the DNS query routing mechanism used by most +# processes on this system. +# +# To view the DNS configuration used by this system, use: +# scutil --dns +# +# SEE ALSO +# dns-sd(1), scutil(8) +# +# This file is automatically generated. +# +nameserver 192.168.68.1 diff --git a/.package/setup.bat b/.package/setup.bat index 53e08d9e7..d11fa32ec 100644 --- a/.package/setup.bat +++ b/.package/setup.bat @@ -115,3 +115,32 @@ if not exist "certs" ( ) else ( echo Directory already exists: certs ) + +:: 9. DNS_API_KEY +findstr /b /i /c:"PDNS_API_KEY=" .env >nul +if errorlevel 1 ( + set "chars=ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + set "pdns_key=" + for /L %%i in (1,1,16) do ( + set /a "rand=!random! %% 62" + for %%j in (!rand!) do set "pdns_key=!pdns_key!!chars:~%%j,1!" 
+ ) + powershell -Command "(gc .\\pdns.conf) -replace supersecretapikey, %pdns_key% | sc .\\pdns.conf -Enc UTF8" + powershell -Command "(gc .\\recursor.conf) -replace supersecretapikey, %pdns_key% | sc .\\recursor.conf -Enc UTF8" + echo PDNS_API_KEY=!pdns_key!>> .env +) + +:: 10. DNSDIST_API_KEY +findstr /b /i /c:"PDNS_DIST_KEY=" .env >nul +if errorlevel 1 ( + for /f %%i in ('powershell -command "[Convert]::ToBase64String((1..32|%%{[byte](Get-Random -Max 256)}))"') do set "randkey=%%i" + powershell -Command "(gc .\\dnsdist.conf) -replace supersecretapikey, %randkey% | sc .\\dnsdist.conf -Enc UTF8" + echo PDNS_DIST_KEY=!randkey!>> .env +) + +:: 11. HOST_MACHINE_NAME +findstr /b /i /c:"HOST_MACHINE_NAME=" .env >nul +if errorlevel 1 ( + set "host_machine_name=%COMPUTERNAME%" + echo HOST_MACHINE_NAME=!host_machine_name!>> .env +) diff --git a/.package/setup.sh b/.package/setup.sh index 3e510b402..7f6545a19 100755 --- a/.package/setup.sh +++ b/.package/setup.sh @@ -79,3 +79,24 @@ if [ ! -d "certs" ]; then else echo "Directory already exists: certs" fi + +# DNS_API_KEY +if ! get_env_var "PDNS_API_KEY"; then + dns_api_key=$(openssl rand -hex 16) + sed -i "s|supersecretapikey|${dns_api_key}|g" recursor.conf + sed -i "s|supersecretapikey|${dns_api_key}|g" pdns.conf + add_env_var "PDNS_API_KEY" "$dns_api_key" +fi + +# DNSDIST_API_KEY +if ! get_env_var "PDNS_DIST_KEY"; then + dnsdist_key=$(openssl rand -base64 32) + sed -i "s|supersecretapikey|${dnsdist_key}|g" dnsdist.conf + add_env_var "PDNS_DIST_KEY" "$dnsdist_key" +fi + +# HOST_MACHINE_NAME +if !
get_env_var "HOST_MACHINE_NAME"; then + host_machine_name=$(hostname) + add_env_var "HOST_MACHINE_NAME" "$host_machine_name" +fi diff --git a/.package/traefik.yml b/.package/traefik.yml index bb8d711de..e672aafd7 100644 --- a/.package/traefik.yml +++ b/.package/traefik.yml @@ -7,6 +7,12 @@ api: ping: entryPoint: "ping" +tcp: + serversTransports: + ldap_transport: + proxyProtocol: + version: 2 + entryPoints: ping: address: ":8800" @@ -40,8 +46,6 @@ entryPoints: address: ":749" kpasswd: address: ":464" - bind_dns_udp: - address: ":53/udp" websecure: address: ":443" http: diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index b331dddd5..3cfbca4a2 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -43,7 +43,6 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = READ_ONLY_GROUP_NAME - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) session.execute( @@ -91,7 +90,6 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 return ro_dir.name = "readonly domain controllers" - ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix()) session.execute( diff --git a/app/alembic/versions/19d86e660cf2_fix_krbadmin_access.py b/app/alembic/versions/19d86e660cf2_fix_krbadmin_access.py new file mode 100644 index 000000000..87ed5c90d --- /dev/null +++ b/app/alembic/versions/19d86e660cf2_fix_krbadmin_access.py @@ -0,0 +1,52 @@ +"""Fix krbadmin access. 
+ +Revision ID: 19d86e660cf2 +Revises: ebf19750805e +Create Date: 2026-02-19 11:40:15.805997 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from enums import RoleConstants +from ldap_protocol.roles.exceptions import RoleNotFoundError +from ldap_protocol.roles.role_dao import RoleDAO +from ldap_protocol.roles.role_use_case import RoleUseCase +from ldap_protocol.utils.queries import get_base_directories + +# revision identifiers, used by Alembic. +revision: None | str = "19d86e660cf2" +down_revision: None | str = "ebf19750805e" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + + async def _fix_krbadmin_role(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + role_dao = await cnt.get(RoleDAO) + role_use_case = await cnt.get(RoleUseCase) + + base_dn_list = await get_base_directories(session) + if not base_dn_list: + return + + try: + await role_dao.get_by_name(RoleConstants.KERBEROS_ROLE_NAME) + except RoleNotFoundError: + return + else: + await role_use_case.add_read_only_role_to_krbadmin_group() + + await session.commit() + + op.run_async(_fix_krbadmin_role) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" diff --git a/app/alembic/versions/275222846605_initial_ldap_schema.py b/app/alembic/versions/275222846605_initial_ldap_schema.py index 226c9270b..6994b0c77 100644 --- a/app/alembic/versions/275222846605_initial_ldap_schema.py +++ b/app/alembic/versions/275222846605_initial_ldap_schema.py @@ -50,12 +50,11 @@ def upgrade(container: AsyncContainer) -> None: sa.Column("single_value", sa.Boolean(), nullable=False), sa.Column("no_user_modification", sa.Boolean(), nullable=False), sa.Column("is_system", sa.Boolean(), nullable=False), - sa.Column( - 
"is_included_anr", - sa.Boolean(), - nullable=True, - ), # NOTE: added in f24ed0e49df2_add_filter_anr.py sa.PrimaryKeyConstraint("id"), + # NOTE: it added in 2dadf40c026a_add_system_flags_to_attribute_types.py + sa.Column("system_flags", sa.Integer(), nullable=False), + # NOTE: it added in f24ed0e49df2_add_filter_anr.py + sa.Column("is_included_anr", sa.Boolean(), nullable=True), ) op.create_index( op.f("ix_AttributeTypes_oid"), @@ -359,6 +358,7 @@ async def _create_attribute_types(connection: AsyncConnection) -> None: # noqa: single_value=True, no_user_modification=False, is_system=True, + system_flags=0, is_included_anr=False, ), ) @@ -400,6 +400,9 @@ async def _modify_object_classes(connection: AsyncConnection) -> None: # noqa: # NOTE: it added in f24ed0e49df2_add_filter_anr.py op.drop_column("AttributeTypes", "is_included_anr") + # NOTE: it added in 2dadf40c026a_add_system_flags_to_attribute_types.py + op.drop_column("AttributeTypes", "system_flags") + session.commit() diff --git a/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py b/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py new file mode 100644 index 000000000..b819c1c86 --- /dev/null +++ b/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py @@ -0,0 +1,172 @@ +"""Add systemFlags for AttributeTypes. + +Revision ID: 2dadf40c026a +Revises: f4e6cd18a01d +Create Date: 2026-02-04 09:33:33.218126 + +""" + +import contextlib + +import sqlalchemy as sa +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from sqlalchemy.orm import Session + +from entities import AttributeType +from ldap_protocol.ldap_schema.attribute_type_use_case import ( + AttributeTypeUseCase, +) +from ldap_protocol.ldap_schema.exceptions import AttributeTypeNotFoundError + +# revision identifiers, used by Alembic. 
+revision: None | str = "2dadf40c026a" +down_revision: None | str = "f4e6cd18a01d" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +_NON_REPLICATED_ATTRIBUTES_TYPE_NAMES = ( + "badPasswordTime", + "badPwdCount", + "bridgeheadServerListBL", + "dSCorePropagationData", + "frsComputerReferenceBL", + "fRSMemberReferenceBL", + "isMemberOfDL", + "isPrivilegeHolder", + "lastLogoff", + "lastLogon", + "logonCount", + "managedObjects", + "masteredBy", + "modifiedCount", + "msCOMPartitionSetLink", + "msCOMUserLink", + "msDSAuthenticatedToAccountlist", + "msDSCachedMembership", + "msDSCachedMembershipTimeStamp", + "msDSEnabledFeatureBL", + "msDSExecuteScriptPassword", + "msDSHostServiceAccountBL", + "msDSMasteredBy", + "msDSOIDToGroupLinkBL", + "msDSPSOApplied", + "msDSMembersForAzRoleBL", + "msDSNCType", + "msDSNonMembersBL", + "msDSObjectReferenceBL", + "msDSOperationsForAzRoleBL", + "msDSOperationsForAzTaskBL", + "msDSNCROReplicaLocationsBL", + "msDSReplicationEpoch", + "msDSRetiredReplNCSignatures", + "msDSTasksForAzRoleBL", + "msDSTasksForAzTaskBL", + "msDSRevealedDSAs", + "msDSKrbTgtLinkBL", + "msDSIsFullReplicaFor", + "msDSIsDomainFor", + "msDSIsPartialReplicaFor", + "msDSUSNLastSyncSuccess", + "msDSValueTypeReferenceBL", + "msDSTokenGroupNames", + "msDSTokenGroupNamesGlobalAndUniversal", + "msDSTokenGroupNamesNoGCAcceptable", + "msExchOwnerBL", + "msDFSRMemberReferenceBL", + "msDFSRComputerReferenceBL", + "netbootSCPBL", + "nonSecurityMemberBL", + "objDistName", + "objectGuid", + "partialAttributeDeletionList", + "partialAttributeSet", + "pekList", + "prefixMap", + "queryPolicyBL", + "replPropertyMetaData", + "replUpToDateVector", + "reports", + "repsFrom", + "repsTo", + "rIDNextRID", + "rIDPreviousAllocationPool", + "schemaUpdate", + "serverReferenceBL", + "serverState", + "siteObjectBL", + "subRefs", + "uSNChanged", + "uSNCreated", + "uSNLastObjRem", + "whenChanged", + "msSFU30PosixMemberOf", + "msTSPrimaryDesktopBL", + 
"msTSSecondaryDesktopBL", + "msDSBridgeHeadServersUsed", + "msDSClaimSharesPossibleValuesWithBL", + "msDSMembersOfResourcePropertyListBL", + "msTPMTpmInformationForComputerBL", + "msAuthzMemberRulesInCentralAccessPolicyBL", + "msDSGenerationId", + "msDSIsPrimaryComputerFor", + "msDSTDOEgressBL", + "msDSTDOIngressBL", + "msDSTransformationRulesCompiled", + "msDSIsMemberOfDLTransitive", + "msDSMemberTransitive", + "msDSParentDistName", + "msDSAssignedAuthNPolicySiloBL", + "msDSAuthNPolicySiloMembersBL", + "msDSUserAuthNPolicyBL", + "msDSComputerAuthNPolicyBL", + "msDSServiceAuthNPolicyBL", + "msDSAssignedAuthNPolicyBL", + "msDSKeyPrincipalBL", + "msDSKeyCredentialLinkBL", +) + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + bind = op.get_bind() + session = Session(bind=bind) + + op.add_column( + "AttributeTypes", + sa.Column( + "system_flags", + sa.Integer(), + nullable=True, + server_default=sa.text("0"), + ), + ) + + session.execute(sa.update(AttributeType).values({"system_flags": 0})) + + async def _set_attr_replication_flag(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + at_type_use_case = await cnt.get(AttributeTypeUseCase) + + for name in _NON_REPLICATED_ATTRIBUTES_TYPE_NAMES: + with contextlib.suppress(AttributeTypeNotFoundError): + await at_type_use_case.set_attr_replication_flag( + name, + need_to_replicate=False, + ) + + await session.commit() + + op.run_async(_set_attr_replication_flag) + + op.alter_column("AttributeTypes", "system_flags", nullable=False) + + session.commit() + + +def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 + """Downgrade.""" + op.drop_column("AttributeTypes", "system_flags") diff --git a/app/alembic/versions/379fce54fb08_rename_base_cn_to_cc.py b/app/alembic/versions/379fce54fb08_rename_base_cn_to_cc.py new file mode 100644 index 000000000..6d88c6d0f --- /dev/null +++ 
b/app/alembic/versions/379fce54fb08_rename_base_cn_to_cc.py @@ -0,0 +1,142 @@ +"""Rename base containers. + +users -> Users, groups -> Groups, computers -> Computers. + +Revision ID: 379fce54fb08 +Revises: ec45e3e8aa0f +Create Date: 2026-01-23 12:26:10.758698 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from entities import Attribute, Directory +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. +revision: None | str = "379fce54fb08" +down_revision: None | str = "ec45e3e8aa0f" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +CONTAINER_RENAMES = { + "users": "Users", + "groups": "Groups", + "computers": "Computers", +} + + +async def _update_descendants( + session: AsyncSession, + parent_id: int, + cn_from: str, + cn_to: str, +) -> None: + """Recursively update paths of all descendants.""" + child_dirs = await session.scalars( + select(Directory).where(qa(Directory.parent_id) == parent_id), + ) + + for child_dir in child_dirs: + child_dir.path = [cn_to if p == cn_from else p for p in child_dir.path] + await session.flush() + await _update_descendants( + session, + child_dir.id, + cn_from=cn_from, + cn_to=cn_to, + ) + + +async def _update_attributes( + session: AsyncSession, + old_value: str, + new_value: str, +) -> None: + """Update attribute values containing old DN references.""" + result = await session.execute( + select(Attribute).where( + qa(Attribute.value).ilike(f"%{old_value}%"), + ), + ) + attributes = result.scalars().all() + + for attr in attributes: + if attr.value and old_value in attr.value: + attr.value = attr.value.replace(old_value, new_value) + + await session.flush() + + +async def _rename_container( + session: AsyncSession, + old_name: str, + new_name: str, +) -> None: + """Rename a single container and update all references.""" + container_dir = 
await session.scalar( + select(Directory).where( + qa(Directory.name) == old_name, + qa(Directory.is_system).is_(True), + ), + ) + + if not container_dir: + return + + cn_from = f"cn={old_name}" + cn_to = f"cn={new_name}" + + container_dir.name = new_name + container_dir.path = [ + cn_to if p == cn_from else p for p in container_dir.path + ] + + await session.flush() + + await _update_descendants( + session, + container_dir.id, + cn_from=cn_from, + cn_to=cn_to, + ) + + await _update_attributes(session, cn_from, cn_to) + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade: Rename containers to capitalized versions.""" + + async def _rename_containers( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + for old_name, new_name in CONTAINER_RENAMES.items(): + await _rename_container(session, old_name, new_name) + + await session.commit() + + op.run_async(_rename_containers) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade: Rename containers back to lowercase.""" + + async def _rename_containers_back( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + for old_name, new_name in CONTAINER_RENAMES.items(): + await _rename_container(session, new_name, old_name) + + await session.commit() + + op.run_async(_rename_containers_back) diff --git a/app/alembic/versions/71e642808369_add_directory_is_system.py b/app/alembic/versions/71e642808369_add_directory_is_system.py index 2526190e4..398d6f6df 100644 --- a/app/alembic/versions/71e642808369_add_directory_is_system.py +++ b/app/alembic/versions/71e642808369_add_directory_is_system.py @@ -14,13 +14,10 @@ from sqlalchemy.orm import Session from constants import ( - COMPUTERS_CONTAINER_NAME, DOMAIN_ADMIN_GROUP_NAME, DOMAIN_COMPUTERS_GROUP_NAME, DOMAIN_USERS_GROUP_NAME, - GROUPS_CONTAINER_NAME, 
READ_ONLY_GROUP_NAME, - USERS_CONTAINER_NAME, ) from entities import Directory from ldap_protocol.utils.queries import get_base_directories @@ -56,8 +53,13 @@ async def _indicate_system_directories( if not base_dn_list: return - for base_dn in base_dn_list: - base_dn.is_system = True + await session.execute( + update(Directory) + .where( + qa(Directory.parent_id).is_(None), + ) + .values(is_system=True), + ) await session.flush() @@ -67,13 +69,13 @@ async def _indicate_system_directories( qa(Directory.is_system).is_(False), qa(Directory.name).in_( ( - GROUPS_CONTAINER_NAME, + "groups", DOMAIN_ADMIN_GROUP_NAME, DOMAIN_USERS_GROUP_NAME, READ_ONLY_GROUP_NAME, DOMAIN_COMPUTERS_GROUP_NAME, - COMPUTERS_CONTAINER_NAME, - USERS_CONTAINER_NAME, + "computers", + "users", "services", "krbadmin", "kerberos", diff --git a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py index 5f8608a4a..a4eb7297c 100644 --- a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py +++ b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py @@ -12,7 +12,6 @@ from sqlalchemy import delete, exists, select from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession -from constants import COMPUTERS_CONTAINER_NAME from entities import Directory from extra.alembic_utils import temporary_stub_column from ldap_protocol.roles.role_use_case import RoleUseCase @@ -26,8 +25,9 @@ depends_on: None = None +COMPUTERS = "computers" _OU_COMPUTERS_DATA = { - "name": COMPUTERS_CONTAINER_NAME, + "name": COMPUTERS, "object_class": "organizationalUnit", "attributes": {"objectClass": ["top", "container"]}, "children": [], @@ -53,7 +53,7 @@ async def _create_ou_computers(connection: AsyncConnection) -> None: # noqa: AR exists_ou_computers = await session.scalar( select( exists(Directory) - .where(qa(Directory.name) == COMPUTERS_CONTAINER_NAME), + .where(qa(Directory.name) == COMPUTERS), ), ) # fmt: skip if exists_ou_computers: @@ -68,7 +68,7 @@ async def 
_create_ou_computers(connection: AsyncConnection) -> None: # noqa: AR ou_computers_dir = await session.scalar( select(Directory) - .where(qa(Directory.name) == COMPUTERS_CONTAINER_NAME), + .where(qa(Directory.name) == COMPUTERS), ) # fmt: skip if not ou_computers_dir: raise Exception("Directory 'ou=computers' not found.") @@ -97,7 +97,7 @@ async def _delete_ou_computers(connection: AsyncConnection) -> None: # noqa: AR await session.execute( delete(Directory) - .where(qa(Directory.name) == COMPUTERS_CONTAINER_NAME), + .where(qa(Directory.name) == COMPUTERS), ) # fmt: skip await session.commit() diff --git a/app/alembic/versions/ebf19750805e_add_domain_controllers_ou.py b/app/alembic/versions/ebf19750805e_add_domain_controllers_ou.py new file mode 100644 index 000000000..2ad84def9 --- /dev/null +++ b/app/alembic/versions/ebf19750805e_add_domain_controllers_ou.py @@ -0,0 +1,155 @@ +"""Add OU 'Domain Controllers' if it does not exist. + +Revision ID: ebf19750805e +Revises: 2dadf40c026a +Create Date: 2026-02-17 08:52:28.048004 + +""" + +from typing import Any + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import delete, exists, select +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from config import Settings +from constants import DOMAIN_CONTROLLERS_OU_NAME +from entities import Directory +from enums import SamAccountTypeCodes +from ldap_protocol.auth.setup_gateway import SetupGateway +from ldap_protocol.objects import UserAccountControlFlag +from ldap_protocol.roles.role_use_case import RoleUseCase +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. 
+revision: None | str = "ebf19750805e" +down_revision: None | str = "2dadf40c026a" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +_OU_DOMAIN_CONTROLLERS_DATA: dict[str, Any] = { + "name": DOMAIN_CONTROLLERS_OU_NAME, + "object_class": "organizationalUnit", + "attributes": {"objectClass": ["top", "container"]}, +} + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + + async def _create_domain_controllers_ou( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + settings = await cnt.get(Settings) + session = await cnt.get(AsyncSession) + setup_gateway = await cnt.get(SetupGateway) + role_use_case = await cnt.get(RoleUseCase) + + base_directories = await get_base_directories(session) + if not base_directories: + return + domain_dir = base_directories[0] + + exists_dc_ou = await session.scalar( + select( + exists(Directory) + .where(qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME), + ), + ) # fmt: skip + if exists_dc_ou: + return + + domain_controller_data = [ + { + "name": settings.HOST_MACHINE_SHORT_NAME, + "object_class": "computer", + "attributes": { + "objectClass": ["top"], + "userAccountControl": [ + str( + UserAccountControlFlag.SERVER_TRUST_ACCOUNT.value, + ), + ], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_MACHINE_ACCOUNT), + ], + "sAMAccountName": [settings.HOST_MACHINE_SHORT_NAME], + "ipHostNumber": [settings.DEFAULT_NAMESERVER], + }, + }, + ] + _OU_DOMAIN_CONTROLLERS_DATA["children"] = domain_controller_data + + await setup_gateway.create_dir( + _OU_DOMAIN_CONTROLLERS_DATA, + is_system=True, + domain=domain_dir, + parent=domain_dir, + ) + + dc_ou = await session.scalar( + select(Directory).where( + qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, + ), + ) + if not dc_ou: + raise Exception("Domain Controllers OU was not created") + + dc = await session.scalar( + select(Directory).where( + qa(Directory.name) == 
settings.HOST_MACHINE_SHORT_NAME, + ), + ) + if not dc: + raise Exception("Domain Controller was not created") + + await role_use_case.inherit_parent_aces( + parent_directory=domain_dir, + directory=dc_ou, + ) + await role_use_case.inherit_parent_aces( + parent_directory=dc_ou, + directory=dc, + ) + + await session.commit() + + op.run_async(_create_domain_controllers_ou) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" + + async def _delete_domain_controllers_ou( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + domain_controller_ou = await session.scalar( + select(Directory).where( + qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, + ), + ) + + if not domain_controller_ou: + return + + await session.execute( + delete(Directory).where( + qa(Directory.parent_id) == domain_controller_ou.id, + ), + ) + + await session.execute( + delete(Directory).where( + qa(Directory.id) == domain_controller_ou.id, + ), + ) + await session.commit() + + op.run_async(_delete_domain_controllers_ou) diff --git a/app/alembic/versions/f4e6cd18a01d_add_samaccounttype.py b/app/alembic/versions/f4e6cd18a01d_add_samaccounttype.py new file mode 100644 index 000000000..efb2d0af5 --- /dev/null +++ b/app/alembic/versions/f4e6cd18a01d_add_samaccounttype.py @@ -0,0 +1,89 @@ +"""Add sAMAccountType to existing user/group/computer entries. 
+ +Revision ID: f4e6cd18a01d +Revises: 379fce54fb08 +Create Date: 2026-01-30 13:08:26.299158 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from sqlalchemy.orm import joinedload + +from entities import Attribute, Directory, EntityType +from enums import EntityTypeNames, SamAccountTypeCodes +from repo.pg.tables import queryable_attr as qa + +revision: None | str = "f4e6cd18a01d" +down_revision: None | str = "379fce54fb08" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + +_SAM_ACCOUNT_TYPE_ATTR = "sAMAccountType" +_SECURITY_PRINCIPAL_TYPES = ( + EntityTypeNames.USER, + EntityTypeNames.GROUP, + EntityTypeNames.COMPUTER, +) +_ENTITY_TO_SAM: dict[str, SamAccountTypeCodes] = { + EntityTypeNames.USER: SamAccountTypeCodes.SAM_USER_OBJECT, + EntityTypeNames.GROUP: SamAccountTypeCodes.SAM_GROUP_OBJECT, + EntityTypeNames.COMPUTER: SamAccountTypeCodes.SAM_MACHINE_ACCOUNT, +} + + +def upgrade(container: AsyncContainer) -> None: + """Add sAMAccountType attributes for user/group/computer.""" + + async def _add_samaccounttype(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + entity_types = await session.scalars( + select(EntityType) + .where(qa(EntityType.name).in_(_SECURITY_PRINCIPAL_TYPES)), + ) # fmt: skip + entity_type_ids = [et.id for et in entity_types] + if not entity_type_ids: + return + + has_sam = select( + qa(Attribute.directory_id), + ).where( + qa(Attribute.name).ilike(_SAM_ACCOUNT_TYPE_ATTR.lower()), + ) + dirs_without_sam = await session.scalars( + select(Directory) + .where( + qa(Directory.entity_type_id).in_(entity_type_ids), + ~qa(Directory.id).in_(has_sam), + ) + .options(joinedload(qa(Directory.entity_type))), + ) + + for directory in dirs_without_sam: + sam_value = ( + 
_ENTITY_TO_SAM.get(directory.entity_type.name) + if directory.entity_type + else None + ) + if sam_value is None: + continue + + session.add( + Attribute( + name=_SAM_ACCOUNT_TYPE_ATTR, + value=str(sam_value), + directory_id=directory.id, + ), + ) + + await session.commit() + + op.run_async(_add_samaccounttype) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" diff --git a/app/api/__init__.py b/app/api/__init__.py index 69f1e8f37..ab5235198 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -9,8 +9,8 @@ from .auth.router_mfa import mfa_router from .auth.session_router import session_router from .dhcp.router import dhcp_router +from .dns.router import dns_router from .ldap_schema.entity_type_router import ldap_schema_router -from .main.dns_router import dns_router from .main.krb5_router import krb5_router from .main.router import entry_router from .network.router import network_router diff --git a/app/api/audit/adapter.py b/app/api/audit/adapter.py index e7a39665d..9d438beb5 100644 --- a/app/api/audit/adapter.py +++ b/app/api/audit/adapter.py @@ -4,17 +4,17 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from api.base_adapter import BaseAdapter -from ldap_protocol.policies.audit.dataclasses import ( - AuditDestinationDTO, - AuditPolicyDTO, -) -from ldap_protocol.policies.audit.schemas import ( +from api.audit.schemas import ( AuditDestinationResponse, AuditDestinationSchemaRequest, AuditPolicyResponse, AuditPolicySchemaRequest, ) +from api.base_adapter import BaseAdapter +from ldap_protocol.policies.audit.dataclasses import ( + AuditDestinationDTO, + AuditPolicyDTO, +) from ldap_protocol.policies.audit.service import AuditService @@ -49,7 +49,7 @@ async def get_destinations(self) -> list[AuditDestinationResponse]: """Get all audit destinations.""" return [ AuditDestinationResponse( - id=destination.id, # type: ignore + id=destination.id, name=destination.name, 
service_type=destination.service_type.name.lower(), host=destination.host, diff --git a/app/api/audit/router.py b/app/api/audit/router.py index 4a328e2ef..24ccbe3c9 100644 --- a/app/api/audit/router.py +++ b/app/api/audit/router.py @@ -9,23 +9,24 @@ from fastapi_error_map.routing import ErrorAwareRouter from fastapi_error_map.rules import rule +from api.audit.schemas import ( + AuditDestinationResponse, + AuditDestinationSchemaRequest, + AuditPolicyResponse, + AuditPolicySchemaRequest, +) from api.auth.utils import verify_auth from api.error_routing import ( ERROR_MAP_TYPE, DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.policies.audit.exception import ( AuditAlreadyExistsError, AuditNotFoundError, ) -from ldap_protocol.policies.audit.schemas import ( - AuditDestinationResponse, - AuditDestinationSchemaRequest, - AuditPolicyResponse, - AuditPolicySchemaRequest, -) from .adapter import AuditPoliciesAdapter @@ -59,7 +60,11 @@ async def get_audit_policies( return await audit_adapter.get_policies() -@audit_router.put("/policy/{policy_id}", error_map=error_map) +@audit_router.put( + "/policy/{policy_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update_audit_policy( policy_id: int, policy_data: AuditPolicySchemaRequest, @@ -81,6 +86,7 @@ async def get_audit_destinations( "/destination", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_audit_destination( destination_data: AuditDestinationSchemaRequest, @@ -90,7 +96,11 @@ async def create_audit_destination( return await audit_adapter.create_destination(destination_data) -@audit_router.delete("/destination/{destination_id}", error_map=error_map) +@audit_router.delete( + "/destination/{destination_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_audit_destination( destination_id: 
int, audit_adapter: FromDishka[AuditPoliciesAdapter], @@ -99,7 +109,11 @@ async def delete_audit_destination( await audit_adapter.delete_destination(destination_id) -@audit_router.put("/destination/{destination_id}", error_map=error_map) +@audit_router.put( + "/destination/{destination_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update_audit_destination( destination_id: int, destination_data: AuditDestinationSchemaRequest, diff --git a/app/ldap_protocol/policies/audit/schemas.py b/app/api/audit/schemas.py similarity index 86% rename from app/ldap_protocol/policies/audit/schemas.py rename to app/api/audit/schemas.py index a23387467..40bbf52e9 100644 --- a/app/ldap_protocol/policies/audit/schemas.py +++ b/app/api/audit/schemas.py @@ -1,11 +1,9 @@ -"""Audit policies schemas module. +"""Audit schemas. Copyright (c) 2025 MultiFactor License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from dataclasses import dataclass - from pydantic import BaseModel, Field from enums import AuditDestinationProtocolType, AuditDestinationServiceType @@ -20,8 +18,7 @@ class AuditPolicySchemaRequest(BaseModel): severity: str -@dataclass -class AuditPolicyResponse: +class AuditPolicyResponse(BaseModel): """Audit policy schema.""" id: int @@ -44,8 +41,7 @@ class Config: # noqa: D106 use_enum_values = True -@dataclass -class AuditDestinationResponse: +class AuditDestinationResponse(BaseModel): """Audit destination schema.""" id: int diff --git a/app/api/auth/adapters/auth.py b/app/api/auth/adapters/auth.py index 50ed85ee7..bb7f766fb 100644 --- a/app/api/auth/adapters/auth.py +++ b/app/api/auth/adapters/auth.py @@ -9,14 +9,10 @@ from adaptix.conversion import get_converter from fastapi import Request +from api.auth.schemas import MFAChallengeResponse, OAuth2Form, SetupRequest from api.base_adapter import BaseAdapter from ldap_protocol.auth import AuthManager -from ldap_protocol.auth.dto import SetupDTO -from 
ldap_protocol.auth.schemas import ( - MFAChallengeResponse, - OAuth2Form, - SetupRequest, -) +from ldap_protocol.auth.dto import LoginRequestDTO, SetupDTO from ldap_protocol.dialogue import UserSchema _convert_request_to_dto = get_converter(SetupRequest, SetupDTO) @@ -42,10 +38,13 @@ async def login( :raises HTTPException: 403 if access is forbidden (e.g. not in admins, disabled, expired, or policy failed) :raises HTTPException: 426 if MFA is required - :return: None + :return: MFAChallengeResponse | None """ login_dto = await self._service.login( - form=form, + form=LoginRequestDTO( + username=form.username, + password=form.password, + ), url=request.url_for("callback_mfa"), ip=ip, user_agent=user_agent, @@ -54,7 +53,12 @@ async def login( self._service.set_new_session_key( login_dto.session_key, ) - return login_dto.mfa_challenge + if login_dto.mfa_challenge is not None: + return MFAChallengeResponse( + status=login_dto.mfa_challenge.status, + message=login_dto.mfa_challenge.message, + ) + return None async def reset_password( self, diff --git a/app/api/auth/adapters/mfa.py b/app/api/auth/adapters/mfa.py index 9fa3b4a02..163858ba6 100644 --- a/app/api/auth/adapters/mfa.py +++ b/app/api/auth/adapters/mfa.py @@ -9,10 +9,11 @@ from fastapi import status from fastapi.responses import RedirectResponse +from api.auth.schemas import MFACreateRequest, MFAGetResponse from api.base_adapter import BaseAdapter from ldap_protocol.auth import MFAManager +from ldap_protocol.auth.dto import MFACreateRequestDTO from ldap_protocol.auth.exceptions.mfa import MFATokenError -from ldap_protocol.auth.schemas import MFACreateRequest, MFAGetResponse from ldap_protocol.multifactor import MFA_HTTP_Creds, MFA_LDAP_Creds @@ -25,7 +26,15 @@ async def setup_mfa(self, mfa: MFACreateRequest) -> bool: :param mfa: MFACreateRequest :return: bool """ - return await self._service.setup_mfa(mfa) + return await self._service.setup_mfa( + MFACreateRequestDTO( + mfa_key=mfa.mfa_key, + 
mfa_secret=mfa.mfa_secret, + is_ldap_scope=mfa.is_ldap_scope, + key_name=mfa.key_name, + secret_name=mfa.secret_name, + ), + ) async def remove_mfa(self, scope: str) -> None: """Delete MFA keys by scope. @@ -46,7 +55,16 @@ async def get_mfa( :param mfa_creds_ldap: MFA_LDAP_Creds :return: MFAGetResponse """ - return await self._service.get_mfa(mfa_creds, mfa_creds_ldap) + mfa_get_response = await self._service.get_mfa( + mfa_creds, + mfa_creds_ldap, + ) + return MFAGetResponse( + mfa_key=mfa_get_response.mfa_key, + mfa_secret=mfa_get_response.mfa_secret, + mfa_key_ldap=mfa_get_response.mfa_key_ldap, + mfa_secret_ldap=mfa_get_response.mfa_secret_ldap, + ) async def callback_mfa( self, diff --git a/app/api/auth/router_auth.py b/app/api/auth/router_auth.py index ae8df7bfd..004f8c0c0 100644 --- a/app/api/auth/router_auth.py +++ b/app/api/auth/router_auth.py @@ -13,12 +13,14 @@ from fastapi_error_map.rules import rule from api.auth.adapters import AuthFastAPIAdapter +from api.auth.schemas import MFAChallengeResponse, OAuth2Form, SetupRequest from api.auth.utils import get_ip_from_request, get_user_agent_from_request from api.error_routing import ( ERROR_MAP_TYPE, DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( MFAAPIError, @@ -26,11 +28,6 @@ MFARequiredError, MissingMFACredentialsError, ) -from ldap_protocol.auth.schemas import ( - MFAChallengeResponse, - OAuth2Form, - SetupRequest, -) from ldap_protocol.dialogue import UserSchema from ldap_protocol.identity.exceptions import ( AlreadyConfiguredError, @@ -67,7 +64,7 @@ translator=translator, ), PasswordPolicyError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), UserNotFoundError: rule( @@ -75,7 +72,7 @@ translator=translator, ), AuthValidationError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + 
status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), MFARequiredError: rule( @@ -186,7 +183,7 @@ async def logout( @auth_router.patch( "/user/password", status_code=200, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def password_reset( @@ -229,6 +226,7 @@ async def check_setup( status_code=status.HTTP_200_OK, responses={423: {"detail": "Locked"}}, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def first_setup( request: SetupRequest, diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index 18424c8ca..a003f7a52 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -14,6 +14,7 @@ from fastapi_error_map.rules import rule from api.auth.adapters import MFAFastAPIAdapter +from api.auth.schemas import MFACreateRequest, MFAGetResponse from api.auth.utils import ( get_ip_from_request, get_user_agent_from_request, @@ -24,6 +25,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( ForbiddenError, @@ -34,7 +36,6 @@ NetworkPolicyError, NotFoundError, ) -from ldap_protocol.auth.schemas import MFACreateRequest, MFAGetResponse from ldap_protocol.multifactor import MFA_HTTP_Creds, MFA_LDAP_Creds translator = DomainErrorTranslator(DomainCodes.MFA) @@ -62,7 +63,7 @@ translator=translator, ), InvalidCredentialsError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), NotFoundError: rule( @@ -81,7 +82,7 @@ @mfa_router.post( "/setup", status_code=status.HTTP_201_CREATED, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def setup_mfa( @@ -100,7 +101,7 @@ async def setup_mfa( @mfa_router.delete( "/keys", - dependencies=[Depends(verify_auth)], 
+ dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def remove_mfa( @@ -113,7 +114,7 @@ async def remove_mfa( @mfa_router.post( "/get", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def get_mfa( diff --git a/app/ldap_protocol/auth/schemas.py b/app/api/auth/schemas.py similarity index 71% rename from app/ldap_protocol/auth/schemas.py rename to app/api/auth/schemas.py index fe786189c..3102f2b37 100644 --- a/app/ldap_protocol/auth/schemas.py +++ b/app/api/auth/schemas.py @@ -1,25 +1,14 @@ -"""Schemas for auth module. +"""Auth schemas. Copyright (c) 2025 MultiFactor License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ import re -from dataclasses import dataclass -from datetime import datetime -from ipaddress import IPv4Address, IPv6Address -from typing import Literal from fastapi.param_functions import Form from fastapi.security import OAuth2PasswordRequestForm -from pydantic import ( - BaseModel, - ConfigDict, - Field, - SecretStr, - computed_field, - field_validator, -) +from pydantic import BaseModel, SecretStr, computed_field, field_validator from ldap_protocol.utils.const import EmailStr @@ -96,23 +85,3 @@ class MFAChallengeResponse(BaseModel): status: str message: str - - -@dataclass -class LoginDTO: - """Login Data Transfer Object.""" - - session_key: str | None - mfa_challenge: MFAChallengeResponse | None - - -class SessionContentSchema(BaseModel): - """Session content schema.""" - - model_config = ConfigDict(extra="allow") - - id: int - sign: str = Field("", description="Session signature") - issued: datetime - ip: IPv4Address | IPv6Address - protocol: Literal["ldap", "http"] = "http" diff --git a/app/api/dhcp/adapter.py b/app/api/dhcp/adapter.py index d063ad144..2a680e137 100644 --- a/app/api/dhcp/adapter.py +++ b/app/api/dhcp/adapter.py @@ -7,8 +7,7 @@ from ipaddress import IPv4Address from 
api.base_adapter import BaseAdapter -from ldap_protocol.dhcp import ( - AbstractDHCPManager, +from api.dhcp.schemas import ( DHCPChangeStateSchemaRequest, DHCPLeaseSchemaRequest, DHCPLeaseSchemaResponse, @@ -19,6 +18,7 @@ DHCPSubnetSchemaAddRequest, DHCPSubnetSchemaResponse, ) +from ldap_protocol.dhcp import AbstractDHCPManager from ldap_protocol.dhcp.dataclasses import ( DHCPLease, DHCPOptionData, diff --git a/app/api/dhcp/router.py b/app/api/dhcp/router.py index 053241809..d1eb9b77b 100644 --- a/app/api/dhcp/router.py +++ b/app/api/dhcp/router.py @@ -12,6 +12,17 @@ from fastapi_error_map.rules import rule from api.auth.utils import verify_auth +from api.dhcp.schemas import ( + DHCPChangeStateSchemaRequest, + DHCPLeaseSchemaRequest, + DHCPLeaseSchemaResponse, + DHCPLeaseToReservationErrorResponse, + DHCPReservationSchemaRequest, + DHCPReservationSchemaResponse, + DHCPStateSchemaResponse, + DHCPSubnetSchemaAddRequest, + DHCPSubnetSchemaResponse, +) from api.error_routing import ( ERROR_MAP_TYPE, DishkaErrorAwareRoute, @@ -25,18 +36,7 @@ DHCPEntryNotFoundError, DHCPEntryUpdateError, DHCPOperationError, - DHCPValidatonError, -) -from ldap_protocol.dhcp.schemas import ( - DHCPChangeStateSchemaRequest, - DHCPLeaseSchemaRequest, - DHCPLeaseSchemaResponse, - DHCPLeaseToReservationErrorResponse, - DHCPReservationSchemaRequest, - DHCPReservationSchemaResponse, - DHCPStateSchemaResponse, - DHCPSubnetSchemaAddRequest, - DHCPSubnetSchemaResponse, + DHCPValidationError, ) from .adapter import DHCPAdapter @@ -65,8 +65,8 @@ status=status.HTTP_400_BAD_REQUEST, translator=translator, ), - DHCPValidatonError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + DHCPValidationError: rule( + status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), DHCPOperationError: rule( diff --git a/app/ldap_protocol/dhcp/schemas.py b/app/api/dhcp/schemas.py similarity index 71% rename from app/ldap_protocol/dhcp/schemas.py rename to app/api/dhcp/schemas.py index 
8f3b0a2c6..c3e10dde8 100644 --- a/app/ldap_protocol/dhcp/schemas.py +++ b/app/api/dhcp/schemas.py @@ -1,56 +1,15 @@ -"""Schemas for DHCP manager. +"""DHCP schemas. Copyright (c) 2025 MultiFactor License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from dataclasses import dataclass, field from datetime import datetime from ipaddress import IPv4Address, IPv4Network from pydantic import BaseModel, field_serializer -from .dataclasses import DHCPLease, DHCPReservation, DHCPSubnet -from .enums import DHCPManagerState, KeaDHCPCommands - - -@dataclass -class KeaDHCPCommandRequest: - """Single command request.""" - - command: KeaDHCPCommands - - -@dataclass -class KeaDHCPBaseAPIRequest(KeaDHCPCommandRequest): - """Base request for Kea DHCP API.""" - - arguments: list[int] | dict[str, str] | None = None - service: list[str] = field(default_factory=lambda: ["dhcp4"]) - - -@dataclass -class KeaDHCPAPISubnetRequest(KeaDHCPCommandRequest): - """Request for Kea DHCP API to manage subnets.""" - - subnet4: DHCPSubnet | list[DHCPSubnet] - service: list[str] = field(default_factory=lambda: ["dhcp4"]) - - -@dataclass -class KeaDHCPAPILeaseRequest(KeaDHCPCommandRequest): - """Request for Kea DHCP API to manage leases.""" - - lease: DHCPLease - service: list[str] = field(default_factory=lambda: ["dhcp4"]) - - -@dataclass -class KeaDHCPAPIReservationRequest(KeaDHCPCommandRequest): - """Request for Kea DHCP API to manage reservations.""" - - arguments: DHCPReservation - service: list[str] = field(default_factory=lambda: ["dhcp4"]) +from ldap_protocol.dhcp.enums import DHCPManagerState class DHCPSubnetSchemaAddRequest(BaseModel): diff --git a/app/api/dns/__init__.py b/app/api/dns/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/api/dns/adapter.py b/app/api/dns/adapter.py new file mode 100644 index 000000000..3514e9a80 --- /dev/null +++ b/app/api/dns/adapter.py @@ -0,0 +1,202 @@ +"""DNS adapter. 
+ +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from api.base_adapter import BaseAdapter +from api.dns.schema import ( + DNSServiceForwardZoneCheckRequest, + DNSServiceForwardZoneRequest, + DNSServiceMasterZoneRequest, + DNSServiceRecordCreateRequest, + DNSServiceRecordDeleteRequest, + DNSServiceRecordUpdateRequest, + DNSServiceSetStateRequest, + DNSServiceSetupRequest, + DNSServiceZoneDeleteRequest, +) +from ldap_protocol.dns.dto import ( + DNSForwardServerStatus, + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRecordDTO, + DNSRRSetDTO, + DNSSettingsDTO, +) +from ldap_protocol.dns.enums import DNSRecordType +from ldap_protocol.dns.use_cases import DNSUseCase + + +class DNSFastAPIAdapter(BaseAdapter[DNSUseCase]): + """DNS adapter.""" + + async def create_record( + self, + zone_id: str, + data: DNSServiceRecordCreateRequest, + ) -> None: + """Create DNS record.""" + await self._service.create_record( + zone_id, + DNSRRSetDTO( + name=data.record_name, + type=DNSRecordType(data.record_type), + records=[ + DNSRecordDTO( + content=data.record_value, + disabled=False, + ), + ], + ttl=data.ttl, + ), + ) + + async def delete_record( + self, + zone_id: str, + data: DNSServiceRecordDeleteRequest, + ) -> None: + """Delete DNS record.""" + await self._service.delete_record( + zone_id, + DNSRRSetDTO( + name=data.record_name, + type=data.record_type, + records=[ + DNSRecordDTO( + content=data.record_value, + disabled=False, + ), + ], + ), + ) + + async def update_record( + self, + zone_id: str, + data: DNSServiceRecordUpdateRequest, + ) -> None: + """Update DNS record.""" + await self._service.update_record( + zone_id, + DNSRRSetDTO( + name=data.record_name, + type=data.record_type, + records=[ + DNSRecordDTO( + content=data.record_value, + disabled=False, + ), + ], + ttl=data.ttl, + ), + ) + + async def get_records(self, zone_id: str) -> list[DNSRRSetDTO]: + """Get all DNS records of current zone.""" + 
return await self._service.get_records(zone_id) + + async def get_status(self) -> dict[str, str | None]: + """Get DNS service status.""" + return await self._service.get_status() + + async def set_state( + self, + data: DNSServiceSetStateRequest, + ) -> None: + """Set DNS manager state.""" + await self._service.set_state(data.state) + + async def setup(self, data: DNSServiceSetupRequest | None) -> None: + await self._service.setup( + DNSSettingsDTO( + dns_server_ip=data.dns_ip_address, + tsig_key=data.tsig_key, + domain=data.domain, + default_nameserver=str(data.dns_ip_address), + ) + if data is not None + else data, + ) + + async def create_forward_zone( + self, + data: DNSServiceForwardZoneRequest, + ) -> None: + """Create new DNS forward zone.""" + await self._service.create_forward_zone( + DNSForwardZoneDTO( + id=data.zone_name, + name=data.zone_name, + servers=data.servers, + ), + ) + + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: + """Get list of DNS forward zones with forwarders.""" + return await self._service.get_forward_zones() + + async def update_forward_zone( + self, + data: DNSServiceForwardZoneRequest, + ) -> None: + """Update DNS forward zone with given params.""" + await self._service.update_forward_zone( + DNSForwardZoneDTO( + id=data.zone_name, + name=data.zone_name, + servers=data.servers, + ), + ) + + async def delete_forward_zones( + self, + data: DNSServiceZoneDeleteRequest, + ) -> None: + """Delete DNS forward zones.""" + await self._service.delete_forward_zones(data.zone_ids) + + async def create_master_zone( + self, + data: DNSServiceMasterZoneRequest, + ) -> None: + """Create new DNS zone.""" + await self._service.create_master_zone( + DNSMasterZoneDTO( + id=data.zone_name, + name=data.zone_name, + dnssec=data.dnssec, + ), + ) + + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: + """Get all DNS master zones.""" + return await self._service.get_master_zones() + + async def update_master_zone( + self, + data: 
DNSServiceMasterZoneRequest, + ) -> None: + """Update DNS zone with given params.""" + await self._service.update_master_zone( + DNSMasterZoneDTO( + id=data.zone_name, + name=data.zone_name, + dnssec=data.dnssec, + ), + ) + + async def delete_master_zones( + self, + data: DNSServiceZoneDeleteRequest, + ) -> None: + """Delete DNS zones.""" + await self._service.delete_master_zones(data.zone_ids) + + async def check_forward_zone( + self, + data: DNSServiceForwardZoneCheckRequest, + ) -> list[DNSForwardServerStatus]: + """Check DNS forward zone for availability.""" + return await self._service.check_forward_zone(data.dns_server_ips) diff --git a/app/api/main/dns_router.py b/app/api/dns/router.py similarity index 62% rename from app/api/main/dns_router.py rename to app/api/dns/router.py index d93382512..7f4bad5cb 100644 --- a/app/api/main/dns_router.py +++ b/app/api/dns/router.py @@ -12,30 +12,30 @@ import ldap_protocol.dns.exceptions as dns_exc from api.auth.utils import verify_auth -from api.error_routing import ( - ERROR_MAP_TYPE, - DishkaErrorAwareRoute, - DomainErrorTranslator, -) -from api.main.adapters.dns import DNSFastAPIAdapter -from api.main.schema import ( +from api.dns.adapter import DNSFastAPIAdapter +from api.dns.schema import ( DNSServiceForwardZoneCheckRequest, + DNSServiceForwardZoneRequest, + DNSServiceMasterZoneRequest, DNSServiceRecordCreateRequest, DNSServiceRecordDeleteRequest, DNSServiceRecordUpdateRequest, - DNSServiceReloadZoneRequest, + DNSServiceSetStateRequest, DNSServiceSetupRequest, - DNSServiceZoneCreateRequest, DNSServiceZoneDeleteRequest, - DNSServiceZoneUpdateRequest, ) +from api.error_routing import ( + ERROR_MAP_TYPE, + DishkaErrorAwareRoute, + DomainErrorTranslator, +) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.dns import ( DNSForwardServerStatus, - DNSForwardZone, - DNSRecords, - DNSServerParam, - DNSZone, + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRRSetDTO, ) translator = 
DomainErrorTranslator(DomainCodes.DNS) @@ -43,13 +43,17 @@ error_map: ERROR_MAP_TYPE = { dns_exc.DNSSetupError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), dns_exc.DNSRecordCreateError: rule( status=status.HTTP_400_BAD_REQUEST, translator=translator, ), + dns_exc.DNSRecordGetError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), dns_exc.DNSRecordUpdateError: rule( status=status.HTTP_400_BAD_REQUEST, translator=translator, @@ -62,6 +66,10 @@ status=status.HTTP_400_BAD_REQUEST, translator=translator, ), + dns_exc.DNSZoneGetError: rule( + status=status.HTTP_400_BAD_REQUEST, + translator=translator, + ), dns_exc.DNSZoneUpdateError: rule( status=status.HTTP_400_BAD_REQUEST, translator=translator, @@ -90,45 +98,49 @@ dns_router = ErrorAwareRouter( prefix="/dns", - tags=["DNS_SERVICE"], + tags=["DNS Service"], dependencies=[Depends(verify_auth)], route_class=DishkaErrorAwareRoute, ) -@dns_router.post("/record", error_map=error_map) +@dns_router.post("/record/{zone_id}", error_map=error_map) async def create_record( + zone_id: str, data: DNSServiceRecordCreateRequest, adapter: FromDishka[DNSFastAPIAdapter], ) -> None: """Create DNS record with given params.""" - await adapter.create_record(data) + await adapter.create_record(zone_id, data) -@dns_router.delete("/record", error_map=error_map) -async def delete_single_record( - data: DNSServiceRecordDeleteRequest, +@dns_router.get("/record/{zone_id}", error_map=error_map) +async def get_all_records( + zone_id: str, adapter: FromDishka[DNSFastAPIAdapter], -) -> None: - """Delete DNS record with given params.""" - await adapter.delete_record(data) +) -> list[DNSRRSetDTO]: + """Get all DNS records of current zone.""" + return await adapter.get_records(zone_id) -@dns_router.patch("/record", error_map=error_map) +@dns_router.patch("/record/{zone_id}", error_map=error_map) async def update_record( + zone_id: str, data: 
DNSServiceRecordUpdateRequest, adapter: FromDishka[DNSFastAPIAdapter], ) -> None: """Update DNS record with given params.""" - await adapter.update_record(data) + await adapter.update_record(zone_id, data) -@dns_router.get("/record", error_map=error_map) -async def get_all_records( +@dns_router.delete("/record/{zone_id}", error_map=error_map) +async def delete_single_record( + zone_id: str, + data: DNSServiceRecordDeleteRequest, adapter: FromDishka[DNSFastAPIAdapter], -) -> list[DNSRecords]: - """Get all DNS records of current zone.""" - return await adapter.get_all_records() +) -> None: + """Delete DNS record with given params.""" + await adapter.delete_record(zone_id, data) @dns_router.get("/status", error_map=error_map) @@ -136,32 +148,68 @@ async def get_dns_status( adapter: FromDishka[DNSFastAPIAdapter], ) -> dict[str, str | None]: """Get DNS service status.""" - return await adapter.get_dns_status() + return await adapter.get_status() -@dns_router.post("/setup", error_map=error_map) +@dns_router.post( + "/setup", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def setup_dns( - data: DNSServiceSetupRequest, adapter: FromDishka[DNSFastAPIAdapter], + data: DNSServiceSetupRequest | None = None, ) -> None: """Set up DNS service.""" - await adapter.setup_dns(data) + await adapter.setup(data) -@dns_router.get("/zone", error_map=error_map) -async def get_dns_zone( +@dns_router.post( + "/state", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def set_dns_state( + data: DNSServiceSetStateRequest, adapter: FromDishka[DNSFastAPIAdapter], -) -> list[DNSZone]: - """Get all DNS records of all zones.""" - return await adapter.get_dns_zone() +) -> None: + """Set DNS manager state.""" + await adapter.set_state(data) + + +@dns_router.post("/zone/forward", error_map=error_map) +async def create_forward_zone( + data: DNSServiceForwardZoneRequest, + adapter: FromDishka[DNSFastAPIAdapter], +) -> None: + """Create new 
forward DNS zone.""" + return await adapter.create_forward_zone(data) @dns_router.get("/zone/forward", error_map=error_map) async def get_forward_dns_zones( adapter: FromDishka[DNSFastAPIAdapter], -) -> list[DNSForwardZone]: +) -> list[DNSForwardZoneDTO]: """Get list of DNS forward zones with forwarders.""" - return await adapter.get_forward_dns_zones() + return await adapter.get_forward_zones() + + +@dns_router.patch("/zone/forward", error_map=error_map) +async def update_forward_zone( + data: DNSServiceForwardZoneRequest, + adapter: FromDishka[DNSFastAPIAdapter], +) -> None: + """Update forward DNS zone with given params.""" + await adapter.update_forward_zone(data) + + +@dns_router.delete("/zone/forward", error_map=error_map) +async def delete_forward_zone( + data: DNSServiceZoneDeleteRequest, + adapter: FromDishka[DNSFastAPIAdapter], +) -> None: + """Delete DNS forward zone.""" + await adapter.delete_forward_zones(data) @dns_router.post( @@ -170,30 +218,38 @@ async def get_forward_dns_zones( warn_on_unmapped=False, default_client_error_translator=translator, ) -async def create_zone( - data: DNSServiceZoneCreateRequest, +async def create_master_zone( + data: DNSServiceMasterZoneRequest, adapter: FromDishka[DNSFastAPIAdapter], ) -> None: """Create new DNS zone.""" - await adapter.create_zone(data) + await adapter.create_master_zone(data) + + +@dns_router.get("/zone", error_map=error_map) +async def get_dns_zones( + adapter: FromDishka[DNSFastAPIAdapter], +) -> list[DNSMasterZoneDTO]: + """Get all DNS records of all zones.""" + return await adapter.get_master_zones() @dns_router.patch("/zone", error_map=error_map) -async def update_zone( - data: DNSServiceZoneUpdateRequest, +async def update_master_zone( + data: DNSServiceMasterZoneRequest, adapter: FromDishka[DNSFastAPIAdapter], ) -> None: """Update DNS zone with given params.""" - await adapter.update_zone(data) + await adapter.update_master_zone(data) @dns_router.delete("/zone", error_map=error_map) -async def 
delete_zone( +async def delete_master_zone( data: DNSServiceZoneDeleteRequest, adapter: FromDishka[DNSFastAPIAdapter], ) -> None: """Delete DNS zone.""" - await adapter.delete_zone(data) + await adapter.delete_master_zones(data) @dns_router.post("/forward_check", error_map=error_map) @@ -202,38 +258,4 @@ async def check_dns_forward_zone( adapter: FromDishka[DNSFastAPIAdapter], ) -> list[DNSForwardServerStatus]: """Check given DNS forward zone for availability.""" - return await adapter.check_dns_forward_zone(data) - - -@dns_router.get("/zone/reload/", error_map=error_map) -async def reload_zone( - data: DNSServiceReloadZoneRequest, - adapter: FromDishka[DNSFastAPIAdapter], -) -> None: - """Reload given DNS zone.""" - await adapter.reload_zone(data) - - -@dns_router.patch("/server/options") -async def update_server_options( - data: list[DNSServerParam], - adapter: FromDishka[DNSFastAPIAdapter], -) -> None: - """Update DNS server options.""" - await adapter.update_server_options(data) - - -@dns_router.get("/server/options") -async def get_server_options( - adapter: FromDishka[DNSFastAPIAdapter], -) -> list[DNSServerParam]: - """Get list of modifiable DNS server params.""" - return await adapter.get_server_options() - - -@dns_router.get("/server/restart") -async def restart_server( - adapter: FromDishka[DNSFastAPIAdapter], -) -> None: - """Restart entire DNS server.""" - await adapter.restart_server() + return await adapter.check_forward_zone(data) diff --git a/app/api/dns/schema.py b/app/api/dns/schema.py new file mode 100644 index 000000000..1cc595580 --- /dev/null +++ b/app/api/dns/schema.py @@ -0,0 +1,79 @@ +"""Schemas for DNS router. 
+ +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address + +from pydantic import BaseModel + +from ldap_protocol.dns import DNSManagerState, DNSRecordType + + +class DNSServiceSetStateRequest(BaseModel): + """DNS set state request schema.""" + + state: DNSManagerState + + +class DNSServiceSetupRequest(BaseModel): + """DNS setup request schema.""" + + domain: str + dns_ip_address: IPv4Address | IPv6Address | None = None + tsig_key: str | None = None + + +class DNSServiceRecordBaseRequest(BaseModel): + """DNS setup base schema.""" + + record_name: str + record_type: DNSRecordType + + +class DNSServiceRecordCreateRequest(DNSServiceRecordBaseRequest): + """DNS create request schema.""" + + record_value: str + ttl: int | None = None + + +class DNSServiceRecordDeleteRequest(DNSServiceRecordBaseRequest): + """DNS delete request schema.""" + + record_value: str + + +class DNSServiceRecordUpdateRequest(DNSServiceRecordBaseRequest): + """DNS update request schema.""" + + record_value: str + ttl: int | None = None + + +class DNSServiceForwardZoneRequest(BaseModel): + """DNS zone create request scheme.""" + + zone_name: str + servers: list[str] + + +class DNSServiceMasterZoneRequest(BaseModel): + """DNS zone create request scheme.""" + + zone_name: str + nameserver_ip: str + dnssec: bool = False + + +class DNSServiceZoneDeleteRequest(BaseModel): + """DNS zone delete request scheme.""" + + zone_ids: list[str] + + +class DNSServiceForwardZoneCheckRequest(BaseModel): + """Forwarder DNS server check request scheme.""" + + dns_server_ips: list[IPv4Address | IPv6Address] diff --git a/app/api/ldap_schema/adapters/attribute_type.py b/app/api/ldap_schema/adapters/attribute_type.py index ad1ea6516..73e5f32bc 100644 --- a/app/api/ldap_schema/adapters/attribute_type.py +++ b/app/api/ldap_schema/adapters/attribute_type.py @@ -44,6 +44,7 @@ def _convert_update_uschema_to_dto( 
single_value=request.single_value, no_user_modification=request.no_user_modification, is_system=False, + system_flags=0, is_included_anr=request.is_included_anr, ) diff --git a/app/api/ldap_schema/attribute_type_router.py b/app/api/ldap_schema/attribute_type_router.py index 5a2f1f368..a75a1826a 100644 --- a/app/api/ldap_schema/attribute_type_router.py +++ b/app/api/ldap_schema/attribute_type_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map, ldap_schema_router from api.ldap_schema.adapters.attribute_type import AttributeTypeFastAPIAdapter @@ -16,6 +16,7 @@ AttributeTypeSchema, AttributeTypeUpdateSchema, ) +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -23,6 +24,7 @@ "/attribute_type", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_one_attribute_type( request_data: AttributeTypeSchema[None], @@ -59,6 +61,7 @@ async def get_list_attribute_types_with_pagination( @ldap_schema_router.patch( "/attribute_type/{attribute_type_name}", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def modify_one_attribute_type( attribute_type_name: str, @@ -72,6 +75,7 @@ async def modify_one_attribute_type( @ldap_schema_router.post( "/attribute_types/delete", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def delete_bulk_attribute_types( attribute_types_names: LimitedListType, diff --git a/app/api/ldap_schema/entity_type_router.py b/app/api/ldap_schema/entity_type_router.py index 31de91616..129230b8e 100644 --- a/app/api/ldap_schema/entity_type_router.py +++ b/app/api/ldap_schema/entity_type_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import 
Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.entity_type import LDAPEntityTypeFastAPIAdapter @@ -17,6 +17,7 @@ EntityTypeSchema, EntityTypeUpdateSchema, ) +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,6 +25,7 @@ "/entity_type", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_one_entity_type( request_data: EntityTypeSchema[None], @@ -66,6 +68,7 @@ async def get_entity_type_attributes( @ldap_schema_router.patch( "/entity_type/{entity_type_name}", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def modify_one_entity_type( entity_type_name: str, @@ -76,7 +79,11 @@ async def modify_one_entity_type( await adapter.update(name=entity_type_name, data=request_data) -@ldap_schema_router.post("/entity_type/delete", error_map=error_map) +@ldap_schema_router.post( + "/entity_type/delete", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_bulk_entity_types( entity_type_names: LimitedListType, adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], diff --git a/app/api/ldap_schema/object_class_router.py b/app/api/ldap_schema/object_class_router.py index a351f3b33..a6baced69 100644 --- a/app/api/ldap_schema/object_class_router.py +++ b/app/api/ldap_schema/object_class_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.object_class import ObjectClassFastAPIAdapter @@ -17,6 +17,7 @@ ObjectClassSchema, ObjectClassUpdateSchema, ) +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,6 +25,7 @@ "/object_class", 
status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_one_object_class( request_data: ObjectClassSchema[None], @@ -57,6 +59,7 @@ async def get_list_object_classes_with_pagination( @ldap_schema_router.patch( "/object_class/{object_class_name}", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def modify_one_object_class( object_class_name: str, @@ -67,7 +70,11 @@ async def modify_one_object_class( await adapter.update(object_class_name, request_data) -@ldap_schema_router.post("/object_class/delete", error_map=error_map) +@ldap_schema_router.post( + "/object_class/delete", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_bulk_object_classes( object_classes_names: LimitedListType, adapter: FromDishka[ObjectClassFastAPIAdapter], diff --git a/app/api/ldap_schema/schema.py b/app/api/ldap_schema/schema.py index 9e6453eff..b3dabefb6 100644 --- a/app/api/ldap_schema/schema.py +++ b/app/api/ldap_schema/schema.py @@ -28,6 +28,7 @@ class AttributeTypeSchema(BaseModel, Generic[_IdT]): single_value: bool no_user_modification: bool is_system: bool + system_flags: int = 0 is_included_anr: bool = False object_class_names: list[str] = Field(default_factory=list) diff --git a/app/api/main/adapters/dns.py b/app/api/main/adapters/dns.py deleted file mode 100644 index 352099fad..000000000 --- a/app/api/main/adapters/dns.py +++ /dev/null @@ -1,138 +0,0 @@ -"""DNS adapter. 
- -Copyright (c) 2025 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from api.base_adapter import BaseAdapter -from api.main.schema import ( - DNSServiceForwardZoneCheckRequest, - DNSServiceRecordCreateRequest, - DNSServiceRecordDeleteRequest, - DNSServiceRecordUpdateRequest, - DNSServiceReloadZoneRequest, - DNSServiceSetupRequest, - DNSServiceZoneCreateRequest, - DNSServiceZoneDeleteRequest, - DNSServiceZoneUpdateRequest, -) -from ldap_protocol.dns.base import ( - DNSForwardServerStatus, - DNSForwardZone, - DNSRecords, - DNSServerParam, - DNSZone, -) -from ldap_protocol.dns.use_cases import DNSUseCase - - -class DNSFastAPIAdapter(BaseAdapter[DNSUseCase]): - """DNS adapter.""" - - async def create_record( - self, - data: DNSServiceRecordCreateRequest, - ) -> None: - """Create DNS record.""" - await self._service.create_record( - data.record_name, - data.record_value, - data.record_type, - data.ttl, - data.zone_name, - ) - - async def delete_record( - self, - data: DNSServiceRecordDeleteRequest, - ) -> None: - """Delete DNS record.""" - await self._service.delete_record( - data.record_name, - data.record_value, - data.record_type, - data.zone_name, - ) - - async def update_record( - self, - data: DNSServiceRecordUpdateRequest, - ) -> None: - """Update DNS record.""" - await self._service.update_record( - data.record_name, - data.record_value, - data.record_type, - data.ttl, - data.zone_name, - ) - - async def get_all_records(self) -> list[DNSRecords]: - """Get all DNS records of current zone.""" - return await self._service.get_all_records() - - async def get_dns_status(self) -> dict[str, str | None]: - """Get DNS service status.""" - return await self._service.get_dns_status() - - async def setup_dns(self, data: DNSServiceSetupRequest) -> None: - await self._service.setup_dns( - dns_status=data.dns_status, - domain=data.domain, - dns_ip_address=data.dns_ip_address, - tsig_key=data.tsig_key, - ) - - async def 
get_dns_zone(self) -> list[DNSZone]: - """Get all DNS zones.""" - return await self._service.get_all_zones_records() - - async def get_forward_dns_zones(self) -> list[DNSForwardZone]: - """Get list of DNS forward zones with forwarders.""" - return await self._service.get_forward_zones() - - async def create_zone(self, data: DNSServiceZoneCreateRequest) -> None: - """Create new DNS zone.""" - await self._service.create_zone( - data.zone_name, - data.zone_type, - data.nameserver, - data.params, - ) - - async def update_zone(self, data: DNSServiceZoneUpdateRequest) -> None: - """Update DNS zone with given params.""" - await self._service.update_zone( - data.zone_name, - data.params, - ) - - async def delete_zone(self, data: DNSServiceZoneDeleteRequest) -> None: - """Delete DNS zone.""" - await self._service.delete_zone(data.zone_names) - - async def check_dns_forward_zone( - self, - data: DNSServiceForwardZoneCheckRequest, - ) -> list[DNSForwardServerStatus]: - """Check DNS forward zone for availability.""" - return await self._service.check_dns_forward_zone(data.dns_server_ips) - - async def reload_zone(self, data: DNSServiceReloadZoneRequest) -> None: - """Reload DNS zone.""" - await self._service.reload_zone(data.zone_name) - - async def update_server_options( - self, - data: list[DNSServerParam], - ) -> None: - """Update DNS server options.""" - await self._service.update_server_options(data) - - async def get_server_options(self) -> list[DNSServerParam]: - """Get list of modifiable DNS server params.""" - return await self._service.get_server_options() - - async def restart_server(self) -> None: - """Restart DNS server.""" - await self._service.restart_server() diff --git a/app/api/main/adapters/kerberos.py b/app/api/main/adapters/kerberos.py index 1bbe252e2..c140fcc55 100644 --- a/app/api/main/adapters/kerberos.py +++ b/app/api/main/adapters/kerberos.py @@ -12,7 +12,12 @@ from starlette.background import BackgroundTask from api.base_adapter import BaseAdapter 
-from api.main.schema import KerberosSetupRequest +from api.main.schema import ( + KerberosSetupRequest, + KtaddRequest, + ModifyPrincipalRequest, + PrincipalAddRequest, +) from ldap_protocol.dialogue import LDAPSession, UserSchema from ldap_protocol.kerberos import KerberosState from ldap_protocol.kerberos.service import KerberosService @@ -66,46 +71,29 @@ async def setup_kdc( ) return Response(background=task) - async def add_principal( - self, - primary: str, - instance: str, - ) -> None: + async def add_principal(self, request: PrincipalAddRequest) -> None: """Create principal in Kerberos with given name. :raises HTTPException: on Kerberos errors :return: None """ - return await self._service.add_principal(primary, instance) - - async def rename_principal( - self, - principal_name: str, - principal_new_name: str, - ) -> None: - """Rename principal in Kerberos. - - :raises HTTPException: on Kerberos errors - :return: None - """ - return await self._service.rename_principal( - principal_name, - principal_new_name, + return await self._service.add_principal( + request.principal_name, + password=request.password, + algorithms=request.algorithms, ) - async def reset_principal_pw( - self, - principal_name: str, - new_password: str, - ) -> None: - """Reset principal password in Kerberos. + async def modify_principal(self, request: ModifyPrincipalRequest) -> None: + """Modify principal ( password, algorithms). :raises HTTPException: on Kerberos errors :return: None """ - return await self._service.reset_principal_pw( - principal_name, - new_password, + return await self._service.modify_principal( + principal_name=request.principal_name, + new_name=request.new_name, + algorithms=request.algorithms, + password=request.password, ) async def delete_principal( @@ -121,14 +109,17 @@ async def delete_principal( async def ktadd( self, - names: list[str], + data: KtaddRequest, ) -> StreamingResponse: """Generate keytab and return as streaming response. 
:raises HTTPException: on Kerberos errors :return: StreamingResponse """ - aiter_bytes, task_struct = await self._service.ktadd(names) + aiter_bytes, task_struct = await self._service.ktadd( + data.names, + is_rand_key=data.is_rand_key, + ) task = BackgroundTask( task_struct.func, *task_struct.args, diff --git a/app/api/main/krb5_router.py b/app/api/main/krb5_router.py index 91f64a5b6..37ed49721 100644 --- a/app/api/main/krb5_router.py +++ b/app/api/main/krb5_router.py @@ -23,7 +23,13 @@ DomainErrorTranslator, ) from api.main.adapters.kerberos import KerberosFastAPIAdapter -from api.main.schema import KerberosSetupRequest +from api.main.schema import ( + KerberosSetupRequest, + KtaddRequest, + ModifyPrincipalRequest, + PrincipalAddRequest, +) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import KerberosState @@ -82,7 +88,7 @@ "/setup/tree", response_class=Response, error_map=error_map, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], ) async def setup_krb_catalogue( mail: Annotated[EmailStr, Body()], @@ -106,7 +112,12 @@ async def setup_krb_catalogue( ) -@krb5_router.post("/setup", response_class=Response, error_map=error_map) +@krb5_router.post( + "/setup", + response_class=Response, + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def setup_kdc( data: KerberosSetupRequest, identity_adapter: FromDishka[AuthFastAPIAdapter], @@ -143,15 +154,15 @@ async def setup_kdc( error_map=error_map, ) async def ktadd( - names: Annotated[LIMITED_LIST, Body()], kerberos_adapter: FromDishka[KerberosFastAPIAdapter], + request: KtaddRequest, ) -> StreamingResponse: """Create keytab from kadmin server. 
:param Annotated[LDAPSession, Depends ldap_session: ldap :return bytes: file """ - return await kerberos_adapter.ktadd(names) + return await kerberos_adapter.ktadd(request) @krb5_router.get( @@ -172,13 +183,12 @@ async def get_krb_status( @krb5_router.post( - "/principal/add", - dependencies=[Depends(verify_auth)], + "/principal", + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def add_principal( - primary: Annotated[LIMITED_STR, Body()], - instance: Annotated[LIMITED_STR, Body()], + request: PrincipalAddRequest, kerberos_adapter: FromDishka[KerberosFastAPIAdapter], ) -> None: """Create principal in kerberos with given name. @@ -188,57 +198,24 @@ async def add_principal( :param Annotated[LDAPSession, Depends ldap_session: ldap :raises HTTPException: on failed kamin request. """ - await kerberos_adapter.add_principal(primary, instance) + await kerberos_adapter.add_principal(request) -@krb5_router.patch( - "/principal/rename", - dependencies=[Depends(verify_auth)], +@krb5_router.put( + "/principal", + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) -async def rename_principal( - principal_name: Annotated[LIMITED_STR, Body()], - principal_new_name: Annotated[LIMITED_STR, Body()], +async def modify_principal( + request: ModifyPrincipalRequest, kerberos_adapter: FromDishka[KerberosFastAPIAdapter], ) -> None: - """Rename principal in kerberos with given name. - - \f - :param Annotated[str, Body principal_name: upn - :param Annotated[LIMITED_STR, Body principal_new_name: _description_ - :param Annotated[LDAPSession, Depends ldap_session: ldap - :raises HTTPException: on failed kamin request. 
- """ - await kerberos_adapter.rename_principal( - principal_name, - principal_new_name, - ) - - -@krb5_router.patch( - "/principal/reset", - dependencies=[Depends(verify_auth)], - error_map=error_map, -) -async def reset_principal_pw( - principal_name: Annotated[LIMITED_STR, Body()], - new_password: Annotated[LIMITED_STR, Body()], - kerberos_adapter: FromDishka[KerberosFastAPIAdapter], -) -> None: - """Reset principal password in kerberos with given name. - - \f - :param Annotated[str, Body principal_name: upn - :param Annotated[LIMITED_STR, Body new_password: _description_ - :param Annotated[LDAPSession, Depends ldap_session: ldap - :raises HTTPException: on failed kamin request. - """ - await kerberos_adapter.reset_principal_pw(principal_name, new_password) + await kerberos_adapter.modify_principal(request) @krb5_router.delete( "/principal/delete", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def delete_principal( diff --git a/app/api/main/router.py b/app/api/main/router.py index f4df578e8..44ce09de2 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -16,7 +16,9 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes +from ldap_protocol.custom_requests.rename import RenameRequest from ldap_protocol.identity.exceptions import UnauthorizedError from ldap_protocol.ldap_requests import ( AddRequest, @@ -25,9 +27,13 @@ ModifyRequest, ) from ldap_protocol.ldap_responses import LDAPResult -from ldap_protocol.utils.queries import set_or_update_primary_group +from ldap_protocol.utils.queries import ( + get_group_path_dn_by_primary_group_id, + set_or_update_primary_group, +) from .schema import ( + PrimaryGroupPathDNResponse, PrimaryGroupRequest, SearchRequest, SearchResponse, @@ -37,7 +43,6 @@ translator = DomainErrorTranslator(DomainCodes.LDAP) - error_map: ERROR_MAP_TYPE = { UnauthorizedError: rule( 
status=status.HTTP_401_UNAUTHORIZED, @@ -54,10 +59,7 @@ @entry_router.post("/search", error_map=error_map) -async def search( - request: SearchRequest, - req: Request, -) -> SearchResponse: +async def search(request: SearchRequest, req: Request) -> SearchResponse: """LDAP SEARCH entry request.""" responses = await request.handle_api(req.state.dishka_container) metadata: SearchResultDone = responses.pop(-1) # type: ignore @@ -72,25 +74,31 @@ async def search( ) -@entry_router.post("/add", error_map=error_map) -async def add( - request: AddRequest, - req: Request, -) -> LDAPResult: +@entry_router.post( + "/add", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def add(request: AddRequest, req: Request) -> LDAPResult: """LDAP ADD entry request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update", error_map=error_map) -async def modify( - request: ModifyRequest, - req: Request, -) -> LDAPResult: +@entry_router.patch( + "/update", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def modify(request: ModifyRequest, req: Request) -> LDAPResult: """LDAP MODIFY entry request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update_many", error_map=error_map) +@entry_router.patch( + "/update_many", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def modify_many( requests: list[ModifyRequest], req: Request, @@ -102,25 +110,57 @@ async def modify_many( return results -@entry_router.put("/update/dn", error_map=error_map) -async def modify_dn( - request: ModifyDNRequest, - req: Request, -) -> LDAPResult: +@entry_router.put( + "/update/dn", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def modify_dn(request: ModifyDNRequest, req: Request) -> LDAPResult: """LDAP MODIFY entry DN request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.delete("/delete", 
error_map=error_map) -async def delete( - request: DeleteRequest, +@entry_router.post( + "/update_many/dn", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def modify_dn_many( + requests: list[ModifyDNRequest], req: Request, -) -> LDAPResult: +) -> list[LDAPResult]: + """LDAP MODIFY entry DN request.""" + results = [] + for request in requests: + results.append(await request.handle_api(req.state.dishka_container)) + return results + + +@entry_router.put( + "/rename", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def rename(request: RenameRequest, req: Request) -> LDAPResult: + """LDAP rename entry request.""" + return await request.handle_api(req.state.dishka_container) + + +@entry_router.delete( + "/delete", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) +async def delete(request: DeleteRequest, req: Request) -> LDAPResult: """LDAP DELETE entry request.""" return await request.handle_api(req.state.dishka_container) -@entry_router.post("/delete_many", error_map=error_map) +@entry_router.post( + "/delete_many", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_many( requests: list[DeleteRequest], req: Request, @@ -132,7 +172,10 @@ async def delete_many( return results -@entry_router.post("/set_primary_group") +@entry_router.post( + "/set_primary_group", + dependencies=[Depends(require_master_db)], +) async def set_primary_group( request: PrimaryGroupRequest, session: FromDishka[AsyncSession], @@ -146,3 +189,20 @@ async def set_primary_group( ) except (ValueError, IntegrityError): raise HTTPException(status_code=400, detail="Invalid request") + + +@entry_router.get("/group/primary/{primary_group_id}") +async def get_group_path_dn_by_primary_grp_id( + primary_group_id: int, + session: FromDishka[AsyncSession], +) -> PrimaryGroupPathDNResponse: + """Get group path DN by primary group ID.""" + try: + path_dn = await 
get_group_path_dn_by_primary_group_id( + primary_group_id, + session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Invalid primary group ID") + + return PrimaryGroupPathDNResponse(path_dn=path_dn) diff --git a/app/api/main/schema.py b/app/api/main/schema.py index 537b0af7c..5ea6545a8 100644 --- a/app/api/main/schema.py +++ b/app/api/main/schema.py @@ -4,7 +4,6 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from ipaddress import IPv4Address, IPv6Address from typing import final from dishka import AsyncContainer @@ -12,7 +11,6 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from entities import Directory -from ldap_protocol.dns import DNSManagerState, DNSZoneParam, DNSZoneType from ldap_protocol.filter_interpreter import ( Filter, FilterInterpreterProtocol, @@ -70,82 +68,28 @@ class KerberosSetupRequest(BaseModel): stash_password: SecretStr -class DNSServiceSetupRequest(BaseModel): - """DNS setup request schema.""" +class PrincipalAddRequest(BaseModel): + """Request schema for POST /principal/add.""" - dns_status: DNSManagerState - domain: str - dns_ip_address: IPv4Address | IPv6Address | None = None - tsig_key: str | None = None + principal_name: str + algorithms: list[str] | None = None + password: str | None = None -class DNSServiceRecordBaseRequest(BaseModel): - """DNS setup base schema.""" +class KtaddRequest(BaseModel): + """Request schema for POST /ktadd.""" - record_name: str - record_type: str - zone_name: str | None = None + names: list[str] + is_rand_key: bool = False -class DNSServiceRecordCreateRequest(DNSServiceRecordBaseRequest): - """DNS create request schema.""" +class ModifyPrincipalRequest(BaseModel): + """Request schema for PUT /principal (full modify).""" - record_value: str - ttl: int | None = None - - -class DNSServiceRecordDeleteRequest(DNSServiceRecordBaseRequest): - """DNS delete request schema.""" - - record_value: str - - -class 
DNSServiceRecordUpdateRequest(DNSServiceRecordBaseRequest): - """DNS update request schema.""" - - record_value: str | None = None - ttl: int | None = None - - -class DNSServiceZoneCreateRequest(BaseModel): - """DNS zone create request scheme.""" - - zone_name: str - zone_type: DNSZoneType - nameserver: str | None = None - params: list[DNSZoneParam] - - -class DNSServiceZoneUpdateRequest(BaseModel): - """DNS zone update request scheme.""" - - zone_name: str - params: list[DNSZoneParam] - - -class DNSServiceZoneDeleteRequest(BaseModel): - """DNS zone delete request scheme.""" - - zone_names: list[str] - - -class DNSServiceReloadZoneRequest(BaseModel): - """DNS zone reload request scheme.""" - - zone_name: str - - -class DNSServiceForwardZoneCheckRequest(BaseModel): - """Forwarder DNS server check request scheme.""" - - dns_server_ips: list[IPv4Address | IPv6Address] - - -class DNSServiceOptionsUpdateRequest(BaseModel): - """DNS server options update request scheme.""" - - name: str - value: str | list[str] = "" + principal_name: str + new_name: str | None = None + algorithms: list[str] | None = None + password: str | None = None class PrimaryGroupRequest(BaseModel): @@ -153,3 +97,9 @@ class PrimaryGroupRequest(BaseModel): directory_dn: GRANT_DN_STRING group_dn: GRANT_DN_STRING + + +class PrimaryGroupPathDNResponse(BaseModel): + """Response schema for getting group path DN by primary group ID.""" + + path_dn: str diff --git a/app/api/network/router.py b/app/api/network/router.py index bc65ed858..71c87cb5b 100644 --- a/app/api/network/router.py +++ b/app/api/network/router.py @@ -18,6 +18,7 @@ DomainErrorTranslator, ) from api.network.adapters.network import NetworkPolicyFastAPIAdapter +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.policies.network.exceptions import ( LastActivePolicyError, @@ -38,7 +39,7 @@ error_map: ERROR_MAP_TYPE = { NetworkPolicyAlreadyExistsError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + 
status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), NetworkPolicyNotFoundError: rule( @@ -46,7 +47,7 @@ translator=translator, ), LastActivePolicyError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), } @@ -64,6 +65,7 @@ "", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def add_network_policy( policy: Policy, @@ -97,6 +99,7 @@ async def get_list_network_policies( response_class=RedirectResponse, status_code=status.HTTP_303_SEE_OTHER, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def delete_network_policy( policy_id: int, @@ -114,7 +117,11 @@ async def delete_network_policy( return await adapter.delete(request, policy_id) # type: ignore -@network_router.patch("/{policy_id}", error_map=error_map) +@network_router.patch( + "/{policy_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def switch_network_policy( policy_id: int, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -133,7 +140,11 @@ async def switch_network_policy( return await adapter.switch_network_policy(policy_id) -@network_router.put("", error_map=error_map) +@network_router.put( + "", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update_network_policy( request: PolicyUpdate, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -150,7 +161,11 @@ async def update_network_policy( return await adapter.update(request) -@network_router.post("/swap", error_map=error_map) +@network_router.post( + "/swap", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def swap_network_policy( swap: SwapRequest, adapter: FromDishka[NetworkPolicyFastAPIAdapter], diff --git a/app/api/network/utils.py b/app/api/network/utils.py index 532399eb9..a01db466d 100644 --- a/app/api/network/utils.py +++ b/app/api/network/utils.py @@ -27,6 +27,6 @@ async 
def check_policy_count(session: AsyncSession) -> None: if count.one() == 1: raise HTTPException( - status.HTTP_422_UNPROCESSABLE_ENTITY, + status.HTTP_422_UNPROCESSABLE_CONTENT, "At least one policy should be active", ) diff --git a/app/api/password_policy/password_ban_word_router.py b/app/api/password_policy/password_ban_word_router.py index a0c06a04e..5185124dc 100644 --- a/app/api/password_policy/password_ban_word_router.py +++ b/app/api/password_policy/password_ban_word_router.py @@ -13,6 +13,7 @@ from api.error_routing import DishkaErrorAwareRoute from api.password_policy.adapter import PasswordBanWordsFastAPIAdapter from api.password_policy.error_utils import error_map +from api.utils import require_master_db password_ban_word_router = ErrorAwareRouter( prefix="/password_ban_word", @@ -26,6 +27,7 @@ "/upload_txt", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def upload_ban_words_txt( file: UploadFile, diff --git a/app/api/password_policy/password_policy_router.py b/app/api/password_policy/password_policy_router.py index 812777ecd..36bd206c3 100644 --- a/app/api/password_policy/password_policy_router.py +++ b/app/api/password_policy/password_policy_router.py @@ -13,6 +13,7 @@ from api.password_policy.adapter import PasswordPolicyFastAPIAdapter from api.password_policy.error_utils import error_map from api.password_policy.schemas import PasswordPolicySchema +from api.utils import require_master_db from ldap_protocol.utils.const import GRANT_DN_STRING from .schemas import PriorityT @@ -51,7 +52,11 @@ async def get_password_policy_by_dir_path_dn( return await adapter.get_password_policy_by_dir_path_dn(path_dn) -@password_policy_router.put("/{id_}", error_map=error_map) +@password_policy_router.put( + "/{id_}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update( id_: int, policy: PasswordPolicySchema[PriorityT], @@ -61,7 +66,11 @@ async def update( await 
adapter.update(id_, policy) -@password_policy_router.put("/reset/domain_policy", error_map=error_map) +@password_policy_router.put( + "/reset/domain_policy", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def reset_domain_policy_to_default_config( adapter: FromDishka[PasswordPolicyFastAPIAdapter], ) -> None: diff --git a/app/api/password_policy/user_password_history_router.py b/app/api/password_policy/user_password_history_router.py index 2285c3cdd..9af233c12 100644 --- a/app/api/password_policy/user_password_history_router.py +++ b/app/api/password_policy/user_password_history_router.py @@ -18,6 +18,7 @@ DomainErrorTranslator, ) from api.password_policy.adapter import UserPasswordHistoryResetFastAPIAdapter +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.identity.exceptions import ( AuthorizationError, @@ -39,7 +40,7 @@ user_password_history_router = ErrorAwareRouter( prefix="/user/password_history", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], tags=["User Password history"], route_class=DishkaErrorAwareRoute, ) diff --git a/app/api/shadow/router.py b/app/api/shadow/router.py index ee8938a18..b708babb0 100644 --- a/app/api/shadow/router.py +++ b/app/api/shadow/router.py @@ -8,7 +8,7 @@ from typing import Annotated from dishka import FromDishka -from fastapi import Body, status +from fastapi import Body, Depends, status from fastapi_error_map.routing import ErrorAwareRouter from fastapi_error_map.rules import rule @@ -17,6 +17,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( AuthenticationError, @@ -46,7 +47,7 @@ translator=translator, ), PasswordPolicyError: rule( - status=status.HTTP_422_UNPROCESSABLE_ENTITY, + status=status.HTTP_422_UNPROCESSABLE_CONTENT, translator=translator, ), PermissionError: rule( @@ -67,7 
+68,11 @@ async def proxy_request( return await adapter.proxy_request(principal, ip) -@shadow_router.post("/sync/password", error_map=error_map) +@shadow_router.post( + "/sync/password", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def change_password( principal: Annotated[str, Body(embed=True)], new_password: Annotated[str, Body(embed=True)], diff --git a/app/api/utils.py b/app/api/utils.py new file mode 100644 index 000000000..5f94d56f6 --- /dev/null +++ b/app/api/utils.py @@ -0,0 +1,22 @@ +"""Utils with master database check. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dishka import FromDishka +from dishka.integrations.fastapi import inject +from fastapi import HTTPException, status + +from ldap_protocol.master_check_use_case import MasterCheckUseCase + + +@inject +async def require_master_db( + master_check_use_case: FromDishka[MasterCheckUseCase], +) -> None: + if not await master_check_use_case.check_master(): + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Master DB is not available", + ) diff --git a/app/config.py b/app/config.py index 423eb2bf8..dcf689f9f 100644 --- a/app/config.py +++ b/app/config.py @@ -24,6 +24,8 @@ ) from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from enums import PostgresRWModeType + def _get_vendor_version() -> str: with open("/pyproject.toml", "rb") as f: @@ -34,6 +36,7 @@ class Settings(BaseModel): """Settigns with database dsn.""" DOMAIN: str + HOST_MACHINE_NAME: str DEBUG: bool = False AUTO_RELOAD: bool = False @@ -49,12 +52,20 @@ class Settings(BaseModel): TCP_PACKET_SIZE: int = 1024 COROUTINES_NUM_PER_CLIENT: int = 3 + POSTGRES_RW_MODE: PostgresRWModeType = PostgresRWModeType.SINGLE POSTGRES_SCHEMA: ClassVar[str] = "postgresql+psycopg" - POSTGRES_DB: str = "postgres" + POSTGRES_REPLICA_DB: str = "" + POSTGRES_REPLICA_HOST: str = "" + POSTGRES_REPLICA_USER: 
str = "" + POSTGRES_REPLICA_PASSWORD: str = "" + POSTGRES_REPLICA_CONNECT_TIMEOUT: int = 4 + + POSTGRES_DB: str = "postgres" POSTGRES_HOST: str = "postgres" POSTGRES_USER: str POSTGRES_PASSWORD: str + POSTGRES_CONNECT_TIMEOUT: int = 4 SESSION_STORAGE_URL: RedisDsn = RedisDsn("redis://dragonfly:6379/1") SESSION_KEY_LENGTH: int = 16 @@ -87,6 +98,15 @@ class Settings(BaseModel): AUDIT_SECOND_RETRY_TIME: int = 60 AUDIT_THIRD_RETRY_TIME: int = 1440 + @computed_field # type: ignore + @cached_property + def HOST_MACHINE_SHORT_NAME(self) -> str: # noqa: N802 + """Host machine name part before the first dot.""" + value = self.HOST_MACHINE_NAME.strip() + if not value: + raise ValueError("HOST_MACHINE_NAME is not set or empty") + return value.split(".", 1)[0] + @computed_field # type: ignore @cached_property def POSTGRES_URI(self) -> PostgresDsn: # noqa @@ -99,6 +119,54 @@ def POSTGRES_URI(self) -> PostgresDsn: # noqa f"{self.POSTGRES_DB}", ) + @computed_field # type: ignore + @cached_property + def REPLICA_POSTGRES_URI(self) -> PostgresDsn: # noqa + """Build replica postgres DSN.""" + return PostgresDsn( + f"{self.POSTGRES_SCHEMA}://" + f"{self.POSTGRES_REPLICA_USER}:" + f"{self.POSTGRES_REPLICA_PASSWORD}@" + f"{self.POSTGRES_REPLICA_HOST}/" + f"{self.POSTGRES_REPLICA_DB}", + ) + + @cached_property + def engine(self) -> AsyncEngine: + """Get engine.""" + return create_async_engine( + str(self.POSTGRES_URI), + pool_size=self.INSTANCE_DB_POOL_SIZE, + max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="master", + connect_args={"connect_timeout": self.POSTGRES_CONNECT_TIMEOUT}, + ) + + @cached_property + def replica_engine(self) -> AsyncEngine | None: + if self.POSTGRES_RW_MODE == PostgresRWModeType.SINGLE: + return None + + return create_async_engine( + str(self.REPLICA_POSTGRES_URI), + pool_size=self.INSTANCE_DB_POOL_SIZE, + 
max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="replica", + connect_args={ + "connect_timeout": self.POSTGRES_REPLICA_CONNECT_TIMEOUT, + }, + ) + VENDOR_NAME: ClassVar[str] = "MultiFactor" VENDOR_VERSION: str = Field( default_factory=_get_vendor_version, @@ -130,7 +198,18 @@ def POSTGRES_URI(self) -> PostgresDsn: # noqa autoescape=True, ) - DNS_BIND_HOST: str = "bind_dns" + PDNS_AUTH_SERVER_HOST: str = "pdns_auth" + PDNS_AUTH_SERVER_IP: str = "172.20.0.202" + PDNS_AUTH_SERVER_PORT: int = 8082 + PDNS_RECURSOR_SERVER_HOST: str = "pdns_recursor" + PDNS_RECURSOR_SERVER_IP: str = "172.20.0.200" + PDNS_RECURSOR_SERVER_PORT: int = 8083 + PDNS_DIST_IP: str = "172.20.0.201" + PDNS_DIST_PORT: int = 8084 + PDNS_DIST_CONFIG_PATH: str = "/dnsdist/delta.conf" + PDNS_DIST_KEY: str + PDNS_API_KEY: str + DEFAULT_NAMESERVER: str ENABLE_SQLALCHEMY_LOGGING: bool = False PYTEST_XDIST_WORKER: str = "master" @@ -185,6 +264,12 @@ def MFA_API_URI(self) -> str: # noqa: N802 return "https://api.multifactor.dev" return "https://api.multifactor.ru" + @computed_field # type: ignore + @cached_property + def is_global_catalog(self) -> bool: + """Check if this is Global Catalog server.""" + return self.PORT in (self.GLOBAL_LDAP_PORT, self.GLOBAL_LDAP_TLS_PORT) + @computed_field # type: ignore @cached_property def KRB5_CONFIG_SERVER(self) -> HttpUrl: # noqa: N802 @@ -220,20 +305,6 @@ def check_certs_exist(self) -> bool: """Check if certs exist.""" return os.path.exists(self.SSL_CERT) and os.path.exists(self.SSL_KEY) - @cached_property - def engine(self) -> AsyncEngine: - """Get engine.""" - return create_async_engine( - str(self.POSTGRES_URI), - pool_size=self.INSTANCE_DB_POOL_SIZE, - max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, - pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, - pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, - pool_pre_ping=False, - 
future=True, - echo=False, - ) - @classmethod def from_os(cls) -> "Settings": """Get cls from environ.""" diff --git a/app/constants.py b/app/constants.py index f54d78a35..5086dfad1 100644 --- a/app/constants.py +++ b/app/constants.py @@ -6,11 +6,12 @@ from typing import TypedDict -from enums import EntityTypeNames +from enums import EntityTypeNames, SamAccountTypeCodes -GROUPS_CONTAINER_NAME = "groups" -COMPUTERS_CONTAINER_NAME = "computers" -USERS_CONTAINER_NAME = "users" +GROUPS_CONTAINER_NAME = "Groups" +COMPUTERS_CONTAINER_NAME = "Computers" +USERS_CONTAINER_NAME = "Users" +DOMAIN_CONTROLLERS_OU_NAME = "Domain Controllers" READ_ONLY_GROUP_NAME = "read-only" @@ -24,7 +25,7 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": ["groups"], - "sAMAccountType": ["268435456"], + "sAMAccountType": [str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value)], } @@ -308,7 +309,9 @@ class EntityTypeData(TypedDict): "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [DOMAIN_ADMIN_GROUP_NAME], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], "gidNumber": ["512"], }, "objectSid": 512, @@ -321,7 +324,9 @@ class EntityTypeData(TypedDict): "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [DOMAIN_USERS_GROUP_NAME], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], "gidNumber": ["513"], }, "objectSid": 513, @@ -334,7 +339,9 @@ class EntityTypeData(TypedDict): "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [READ_ONLY_GROUP_NAME], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], "gidNumber": ["521"], }, "objectSid": 521, @@ -347,7 +354,9 @@ class EntityTypeData(TypedDict): "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [DOMAIN_COMPUTERS_GROUP_NAME], - "sAMAccountType": ["268435456"], + 
"sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], "gidNumber": ["515"], }, "objectSid": 515, diff --git a/app/db_routing.py b/app/db_routing.py new file mode 100644 index 000000000..f19ff6e1e --- /dev/null +++ b/app/db_routing.py @@ -0,0 +1,90 @@ +"""Engine registry and routing session. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Any, Sequence + +from sqlalchemy import Delete, Insert, Update, exc as sa_exc +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.orm import Session + +from enums import PostgresRWModeType + + +class EngineRegistry: + _master_engine: AsyncEngine + _replica_engine: AsyncEngine | None + + def __init__( + self, + master_engine: AsyncEngine, + replica_engine: AsyncEngine | None, + ) -> None: + self._master_engine = master_engine + self._replica_engine = replica_engine + + def get_master_engine(self) -> AsyncEngine: + return self._master_engine + + def get_replica_engine(self) -> AsyncEngine: + if self._replica_engine is None: + raise RuntimeError("Replica engine is not configured") + return self._replica_engine + + def get_sync_master_engine(self) -> Engine: + return self._master_engine.sync_engine + + def get_sync_replica_engine(self) -> Engine: + if self._replica_engine is None: + raise RuntimeError("Replica engine is not configured") + return self._replica_engine.sync_engine + + +class RoutingSession(Session): + _force_master: bool = False + + @property + def engine_registry(self) -> EngineRegistry: + engine_registry = self.info.get("engine_registry") + if engine_registry is None: + raise RuntimeError("Engine registry is not configured") + return engine_registry + + @property + def rw_mode(self) -> PostgresRWModeType: + rw_mode = self.info.get("rw_mode") + if rw_mode is None: + raise RuntimeError("RW mode is not configured") + return rw_mode + + def 
set_force_master(self, value: bool) -> None: + self._force_master = value + + def get_bind(self, mapper=None, *, clause=None, **kw) -> Engine: # type: ignore # noqa: ARG002 + if self.rw_mode == PostgresRWModeType.SINGLE: + return self.engine_registry.get_sync_master_engine() + + if isinstance(clause, Update | Insert | Delete): + self._force_master = True + return self.engine_registry.get_sync_master_engine() + + if self._force_master or self._flushing: + return self.engine_registry.get_sync_master_engine() + else: + return self.engine_registry.get_sync_replica_engine() + + def flush(self, objects: Sequence[Any] | None = None) -> None: + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + if self._is_clean(): + return + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + self._force_master = True diff --git a/app/entities.py b/app/entities.py index 535da02f6..9b4d70e16 100644 --- a/app/entities.py +++ b/app/entities.py @@ -69,6 +69,7 @@ class AttributeType: single_value: bool = False no_user_modification: bool = False is_system: bool = False + system_flags: int = 0 # NOTE: ms-adts/cf133d47-b358-4add-81d3-15ea1cff9cd9 # see section 3.1.1.2.3 `searchFlags` (fANR) for details is_included_anr: bool = False @@ -240,6 +241,7 @@ class Directory: "objectguid", "objectsid", "entitytypename", + "name", } def get_dn_prefix(self) -> DistinguishedNamePrefix: @@ -259,10 +261,6 @@ def get_dn(self, dn: str = "cn") -> str: def is_domain(self) -> bool: return not self.parent_id and self.object_class == "domain" - @property - def host_principal(self) -> str: - return f"host/{self.name}" - @property def path_dn(self) -> str: return ",".join(reversed(self.path)) @@ -372,9 +370,6 @@ class User: "homedirectory": "homeDirectory", } - def get_upn_prefix(self) -> str: - return self.user_principal_name.split("@")[0] - def is_expired(self) -> bool: if self.account_exp is None: return False diff --git a/app/enums.py 
b/app/enums.py index f482b928e..2c991d9f4 100644 --- a/app/enums.py +++ b/app/enums.py @@ -12,6 +12,13 @@ from typing import Iterable, Self +class PostgresRWModeType(StrEnum): + """Postgres read/write mode type.""" + + SINGLE = "single" + REPLICATION = "replication" + + class AceType(IntEnum): """ACE types.""" @@ -105,9 +112,9 @@ class RoleConstants(StrEnum): READ_ONLY_ROLE_NAME = "Read Only Role" KERBEROS_ROLE_NAME = "Kerberos Role" - DOMAIN_ADMINS_GROUP_CN = "cn=domain admins,cn=groups," - READONLY_GROUP_CN = "cn=read-only,cn=groups," - KERBEROS_GROUP_CN = "cn=krbadmin,cn=groups," + DOMAIN_ADMINS_GROUP_CN = "cn=domain admins,cn=Groups," + READONLY_GROUP_CN = "cn=read-only,cn=Groups," + KERBEROS_GROUP_CN = "cn=krbadmin,cn=Groups," @verify(UNIQUE) @@ -150,6 +157,7 @@ class AuthorizationRules(IntFlag): ATTRIBUTE_TYPE_GET_PAGINATOR = auto() ATTRIBUTE_TYPE_UPDATE = auto() ATTRIBUTE_TYPE_DELETE_ALL_BY_NAMES = auto() + ATTRIBUTE_TYPE_SET_ATTR_REPLICATION_FLAG = auto() ENTITY_TYPE_GET = auto() ENTITY_TYPE_CREATE = auto() @@ -170,24 +178,22 @@ class AuthorizationRules(IntFlag): DNS_UPDATE_RECORD = auto() DNS_GET_ALL_RECORDS = auto() DNS_GET_DNS_STATUS = auto() - DNS_GET_ALL_ZONES_RECORDS = auto() - DNS_GET_FORWARD_ZONES = auto() - DNS_CREATE_ZONE = auto() - DNS_UPDATE_ZONE = auto() - DNS_DELETE_ZONE = auto() + DNS_GET_MASTER_ZONES = auto() + DNS_GET_FWD_ZONES = auto() + DNS_DELETE_MASTER_ZONES = auto() + DNS_DELETE_FWD_ZONES = auto() + DNS_CREATE_MASTER_ZONE = auto() + DNS_CREATE_FWD_ZONE = auto() + DNS_UPDATE_MASTER_ZONE = auto() + DNS_UPDATE_FWD_ZONE = auto() DNS_CHECK_DNS_FORWARD_ZONE = auto() - DNS_RELOAD_ZONE = auto() - DNS_UPDATE_SERVER_OPTIONS = auto() - DNS_GET_SERVER_OPTIONS = auto() - DNS_RESTART_SERVER = auto() KRB_SETUP_CATALOGUE = auto() KRB_SETUP_KDC = auto() KRB_KTADD = auto() KRB_GET_STATUS = auto() KRB_ADD_PRINCIPAL = auto() - KRB_RENAME_PRINCIPAL = auto() - KRB_RESET_PRINCIPAL_PW = auto() + KRB_MODIFY_PRINCIPAL = auto() KRB_DELETE_PRINCIPAL = auto() 
AUDIT_GET_POLICIES = auto() @@ -254,3 +260,23 @@ class DomainCodes(IntEnum): DHCP = 12 LDAP_SCHEMA = 13 SHADOW = 14 + + +class SamAccountTypeCodes(IntEnum): + """SAM Account Type values.""" + + SAM_DOMAIN_OBJECT = 0 + SAM_GROUP_OBJECT = 268435456 + SAM_NON_SECURITY_GROUP_OBJECT = 268435457 + SAM_ALIAS_OBJECT = 536870912 + SAM_NON_SECURITY_ALIAS_OBJECT = 536870913 + SAM_USER_OBJECT = 805306368 + SAM_MACHINE_ACCOUNT = 805306369 + SAM_TRUST_ACCOUNT = 805306370 + SAM_APP_BASIC_GROUP = 1073741824 + SAM_APP_QUERY_GROUP = 1073741825 + + @staticmethod + def to_hex(value: int) -> str: + """Convert decimal value to hex string.""" + return hex(value) diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py new file mode 100644 index 000000000..3f700328a --- /dev/null +++ b/app/extra/scripts/add_domain_controller.py @@ -0,0 +1,147 @@ +"""Add domain controller. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from loguru import logger +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings +from constants import DOMAIN_CONTROLLERS_OU_NAME +from entities import Attribute, Directory +from enums import SamAccountTypeCodes +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.objects import UserAccountControlFlag +from ldap_protocol.roles.role_use_case import RoleUseCase +from ldap_protocol.utils.helpers import create_object_sid +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + + +async def _add_domain_controller( + session: AsyncSession, + role_use_case: RoleUseCase, + entity_type_dao: EntityTypeDAO, + settings: Settings, + domain: Directory, + dc_ou_dir: Directory, +) -> None: + dc_directory = Directory( + object_class="", + name=settings.HOST_MACHINE_SHORT_NAME, + is_system=True, + ) + 
dc_directory.create_path(dc_ou_dir) + session.add(dc_directory) + await session.flush() + + dc_directory.parent_id = dc_ou_dir.id + dc_directory.object_sid = create_object_sid(domain, dc_directory.id) + await session.flush() + + attributes = [ + Attribute( + name="objectClass", + value="top", + directory_id=dc_directory.id, + ), + Attribute( + name="objectClass", + value="computer", + directory_id=dc_directory.id, + ), + Attribute( + name="sAMAccountName", + value=settings.HOST_MACHINE_SHORT_NAME, + directory_id=dc_directory.id, + ), + Attribute( + name="userAccountControl", + value=str( + UserAccountControlFlag.SERVER_TRUST_ACCOUNT, + ), + directory_id=dc_directory.id, + ), + Attribute( + name="sAMAccountType", + value=str(SamAccountTypeCodes.SAM_MACHINE_ACCOUNT), + directory_id=dc_directory.id, + ), + Attribute( + name="ipHostNumber", + value=settings.DEFAULT_NAMESERVER, + directory_id=dc_directory.id, + ), + Attribute( + name="cn", + value=settings.HOST_MACHINE_SHORT_NAME, + directory_id=dc_directory.id, + ), + ] + + session.add_all(attributes) + await session.flush() + + await role_use_case.inherit_parent_aces( + parent_directory=dc_ou_dir, + directory=dc_directory, + ) + await entity_type_dao.attach_entity_type_to_directory( + directory=dc_directory, + is_system_entity_type=False, + object_class_names={"top", "computer"}, + ) + await session.flush() + + +async def add_domain_controller( + session: AsyncSession, + settings: Settings, + role_use_case: RoleUseCase, + entity_type_dao: EntityTypeDAO, +) -> None: + logger.info("Adding domain controller.") + + domains = await get_base_directories(session) + if not domains: + logger.debug("Cannot get base directory") + return + + domain_controllers_ou = await session.scalar( + select(Directory).where( + qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, + ), + ) + + if not domain_controllers_ou: + logger.debug("Domain controllers OU does not exist.") + return + + domain_controller = await session.scalar( + 
select(qa(Directory.id).distinct()) + .join(qa(Directory.attributes)) + .where( + qa(Directory.parent_id) == domain_controllers_ou.id, + qa(Attribute.name) == "ipHostNumber", + qa(Attribute.value) == settings.DEFAULT_NAMESERVER, + ), + ) + + if domain_controller: + logger.debug("Domain controllers already exists") + return + + await _add_domain_controller( + session=session, + role_use_case=role_use_case, + entity_type_dao=entity_type_dao, + settings=settings, + domain=domains[0], + dc_ou_dir=domain_controllers_ou, + ) + + logger.debug("Domain controller added.") + + await session.commit() diff --git a/app/extra/scripts/principal_block_user_sync.py b/app/extra/scripts/principal_block_user_sync.py index 858a65a7d..d4c483728 100644 --- a/app/extra/scripts/principal_block_user_sync.py +++ b/app/extra/scripts/principal_block_user_sync.py @@ -32,7 +32,7 @@ async def principal_block_sync( if "@" in user.user_principal_name: principal_postfix = user.user_principal_name.split("@")[1].upper() - principal_name = f"{user.get_upn_prefix()}@{principal_postfix}" + principal_name = f"{user.sam_account_name}@{principal_postfix}" else: continue diff --git a/app/extra/scripts/uac_sync.py b/app/extra/scripts/uac_sync.py index f0623e1b1..8f9ce4f46 100644 --- a/app/extra/scripts/uac_sync.py +++ b/app/extra/scripts/uac_sync.py @@ -71,7 +71,7 @@ async def disable_accounts( ) # fmt: skip async for user in users: - await kadmin.lock_principal(user.get_upn_prefix()) + await kadmin.lock_principal(user.sam_account_name) await add_lock_and_expire_attributes( session, diff --git a/app/extra/scripts/update_krb5_config.py b/app/extra/scripts/update_krb5_config.py index b0ecda0f6..325e2a47c 100644 --- a/app/extra/scripts/update_krb5_config.py +++ b/app/extra/scripts/update_krb5_config.py @@ -53,7 +53,7 @@ async def update_krb5_config( base_dn_list = await get_base_directories(session) if not base_dn_list: - logger.error("No base directories found") + logger.warning("No base directories found") 
return base_dn = base_dn_list[0].path_dn diff --git a/app/ioc.py b/app/ioc.py index d6489f842..1af2d9c98 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -8,12 +8,12 @@ import httpx import redis.asyncio as redis +from db_routing import EngineRegistry, RoutingSession from dishka import Provider, Scope, from_context, provide from fastapi import Request from loguru import logger from sqlalchemy.ext.asyncio import ( AsyncConnection, - AsyncEngine, AsyncSession, async_sessionmaker, ) @@ -26,10 +26,10 @@ ) from api.auth.utils import get_ip_from_request, get_user_agent_from_request from api.dhcp.adapter import DHCPAdapter +from api.dns.adapter import DNSFastAPIAdapter from api.ldap_schema.adapters.attribute_type import AttributeTypeFastAPIAdapter from api.ldap_schema.adapters.entity_type import LDAPEntityTypeFastAPIAdapter from api.ldap_schema.adapters.object_class import ObjectClassFastAPIAdapter -from api.main.adapters.dns import DNSFastAPIAdapter from api.main.adapters.kerberos import KerberosFastAPIAdapter from api.network.adapters.network import NetworkPolicyFastAPIAdapter from api.password_policy.adapter import ( @@ -54,12 +54,20 @@ from ldap_protocol.dialogue import LDAPSession from ldap_protocol.dns import ( AbstractDNSManager, - DNSManagerSettings, - get_dns_manager_class, + DNSManagerState, + DNSSettingsDTO, + DNSStateGateway, + DNSUseCase, + PowerDNSAuthHTTPClient, + PowerDNSDistClient, + PowerDNSManager, + PowerDNSRecursorHTTPClient, + RemoteDNSManager, + StubDNSManager, +) +from ldap_protocol.dns.bind_to_pdns_migration_use_case import ( + BindToPDNSMigrationUseCase, ) -from ldap_protocol.dns.dns_gateway import DNSStateGateway -from ldap_protocol.dns.use_cases import DNSUseCase -from ldap_protocol.dns.utils import resolve_dns_server_ip from ldap_protocol.identity import IdentityProvider from ldap_protocol.identity.provider_gateway import IdentityProviderGateway from ldap_protocol.kerberos import AbstractKadmin, get_kerberos_class @@ -78,6 +86,9 @@ 
LDAPUnbindRequestContext, ) from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO +from ldap_protocol.ldap_schema.attribute_type_system_flags_use_case import ( + AttributeTypeSystemFlagsUseCase, +) from ldap_protocol.ldap_schema.attribute_type_use_case import ( AttributeTypeUseCase, ) @@ -88,6 +99,10 @@ from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.master_check_use_case import ( + MasterCheckUseCase, + MasterGatewayProtocol, +) from ldap_protocol.multifactor import ( Creds, LDAPMultiFactorAPI, @@ -148,10 +163,10 @@ from ldap_protocol.session_storage import RedisSessionStorage, SessionStorage from ldap_protocol.session_storage.repository import SessionRepository from password_utils import PasswordUtils +from repo.pg.master_gateway import PGMasterGateway SessionStorageClient = NewType("SessionStorageClient", redis.Redis) KadminHTTPClient = NewType("KadminHTTPClient", httpx.AsyncClient) -DNSManagerHTTPClient = NewType("DNSManagerHTTPClient", httpx.AsyncClient) MFAHTTPClient = NewType("MFAHTTPClient", httpx.AsyncClient) DHCPManagerHTTPClient = NewType("DHCPManagerHTTPClient", httpx.AsyncClient) @@ -163,17 +178,27 @@ class MainProvider(Provider): settings = from_context(provides=Settings, scope=Scope.APP) @provide(scope=Scope.APP) - def get_engine(self, settings: Settings) -> AsyncEngine: - """Get async engine.""" - return settings.engine + def get_engine_registry(self, settings: Settings) -> EngineRegistry: + return EngineRegistry( + master_engine=settings.engine, + replica_engine=settings.replica_engine, + ) @provide(scope=Scope.APP) def get_session_factory( self, - engine: AsyncEngine, + settings: Settings, + engine_registry: EngineRegistry, ) -> async_sessionmaker[AsyncSession]: """Create session factory.""" - return async_sessionmaker(engine, 
expire_on_commit=False) + return async_sessionmaker( + sync_session_class=RoutingSession, + expire_on_commit=False, + info={ + "engine_registry": engine_registry, + "rw_mode": settings.POSTGRES_RW_MODE, + }, + ) @provide(scope=Scope.REQUEST) async def create_session( @@ -233,14 +258,6 @@ def get_kadmin( """ return kadmin_class(client) - @provide(scope=Scope.REQUEST) - async def get_dns_mngr_class( - self, - dns_state_gateway: DNSStateGateway, - ) -> type[AbstractDNSManager]: - """Get DNS manager type.""" - return await get_dns_manager_class(dns_state_gateway) - @provide(scope=Scope.REQUEST, provides=AuthorizationProviderProtocol) async def get_auth_provider_class( self, @@ -248,40 +265,99 @@ async def get_auth_provider_class( """Get AuthorizationProvider.""" return None - @provide(scope=Scope.REQUEST) - async def get_dns_mngr_settings( + @provide(scope=Scope.APP) + async def get_power_dns_auth_http_client( self, settings: Settings, - dns_state_gateway: DNSStateGateway, - ) -> DNSManagerSettings: - """Get DNS manager's settings.""" - resolve_coro = resolve_dns_server_ip( - settings.DNS_BIND_HOST, - ) - return await dns_state_gateway.get_dns_manager_settings( - resolve_coro, - ) + ) -> AsyncIterator[PowerDNSAuthHTTPClient]: + """Get PowerDNS Auth server client.""" + async with httpx.AsyncClient( + base_url=f"http://{settings.PDNS_AUTH_SERVER_HOST}:{settings.PDNS_AUTH_SERVER_PORT}/api/v1/servers/localhost", + headers={"X-API-Key": settings.PDNS_API_KEY}, + ) as client: + yield PowerDNSAuthHTTPClient(http_client=client) @provide(scope=Scope.APP) - async def get_dns_http_client( + async def get_power_dns_recursor_http_client( self, settings: Settings, - ) -> AsyncIterator[DNSManagerHTTPClient]: - """Get async client for DNS manager.""" + ) -> AsyncIterator[PowerDNSRecursorHTTPClient]: + """Get PowerDNS Auth server client.""" async with httpx.AsyncClient( - base_url=f"http://{settings.DNS_BIND_HOST}:8000", + 
base_url=f"http://{settings.PDNS_RECURSOR_SERVER_HOST}:{settings.PDNS_RECURSOR_SERVER_PORT}/api/v1/servers/localhost", + headers={"X-API-Key": settings.PDNS_API_KEY}, ) as client: - yield DNSManagerHTTPClient(client) + yield PowerDNSRecursorHTTPClient(http_client=client) + + @provide(scope=Scope.APP) + def get_power_dns_dist_client( + self, + settings: Settings, + ) -> PowerDNSDistClient: + """Get PowerDNS dist client.""" + return PowerDNSDistClient( + dnsdist_host=settings.PDNS_DIST_IP, + dnsdist_port=settings.PDNS_DIST_PORT, + dnsdist_key=settings.PDNS_DIST_KEY, + config_path=settings.PDNS_DIST_CONFIG_PATH, + ) + + @provide(scope=Scope.REQUEST) + async def get_dns_mngr_settings( + self, + dns_state_gateway: DNSStateGateway, + settings: Settings, + root_dse_gw: DomainReadProtocol, + ) -> AsyncIterator[DNSSettingsDTO]: + """Get DNS manager's settings.""" + domain = await root_dse_gw.get_domain() + dns_settings = await dns_state_gateway.get_dns_manager_settings( + settings, + domain.name, + ) + yield dns_settings @provide(scope=Scope.REQUEST) - def get_dns_mngr( + async def get_dns_mngr( self, - settings: DNSManagerSettings, - dns_manager_class: type[AbstractDNSManager], - http_client: DNSManagerHTTPClient, - ) -> AbstractDNSManager: + dns_settings: DNSSettingsDTO, + dns_state_gateway: DNSStateGateway, + power_dns_auth_client: PowerDNSAuthHTTPClient, + power_dns_recursor_client: PowerDNSRecursorHTTPClient, + power_dns_dist_client: PowerDNSDistClient, + ) -> AsyncIterator[AbstractDNSManager]: """Get DNSManager class.""" - return dns_manager_class(settings=settings, http_client=http_client) + state = await dns_state_gateway.get_state() + if state == DNSManagerState.SELFHOSTED: + yield PowerDNSManager( + settings=dns_settings, + power_dns_auth_client=power_dns_auth_client, + power_dns_recursor_client=power_dns_recursor_client, + dnsdist_client=power_dns_dist_client, + ) + elif state == DNSManagerState.HOSTED: + yield RemoteDNSManager(settings=dns_settings) + else: + 
yield StubDNSManager(settings=dns_settings) + + @provide(scope=Scope.REQUEST) + async def get_dns_migration_usecase( + self, + dns_settings: DNSSettingsDTO, + power_dns_auth_client: PowerDNSAuthHTTPClient, + power_dns_recursor_client: PowerDNSRecursorHTTPClient, + power_dns_dist_client: PowerDNSDistClient, + ) -> AsyncIterator[BindToPDNSMigrationUseCase]: + """Get DNS migration manager class.""" + yield BindToPDNSMigrationUseCase( + PowerDNSManager( + settings=dns_settings, + power_dns_auth_client=power_dns_auth_client, + power_dns_recursor_client=power_dns_recursor_client, + dnsdist_client=power_dns_dist_client, + ), + dns_settings=dns_settings, + ) @provide(scope=Scope.APP) async def get_redis_for_sessions( @@ -435,6 +511,10 @@ def get_dhcp_mngr( scope=Scope.RUNTIME, ) attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST) + attribute_type_system_flags_use_case = provide( + AttributeTypeSystemFlagsUseCase, + scope=Scope.REQUEST, + ) object_class_dao = provide(ObjectClassDAO, scope=Scope.REQUEST) entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST) attribute_type_use_case = provide( @@ -472,10 +552,10 @@ def get_dhcp_mngr( ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST) role_use_case = provide(RoleUseCase, scope=Scope.REQUEST) session_repository = provide(SessionRepository, scope=Scope.REQUEST) - entity_type_use_case = provide(EntityTypeUseCase, scope=Scope.REQUEST) dns_use_case = provide(DNSUseCase, scope=Scope.REQUEST) dns_state_gateway = provide(DNSStateGateway, scope=Scope.REQUEST) + rootdse_gw = provide( SADomainGateway, provides=DomainReadProtocol, @@ -571,6 +651,19 @@ def get_audit_monitor( session_key=session_key, ) + @provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol) + async def get_master_gateway( + self, + session: AsyncSession, + settings: Settings, + ) -> PGMasterGateway: + return PGMasterGateway(session, settings) + + master_check_use_case = provide( + MasterCheckUseCase, + scope=Scope.REQUEST, + ) + 
identity_provider_gateway = provide( IdentityProviderGateway, scope=Scope.REQUEST, @@ -720,6 +813,29 @@ async def get_session( yield session await session.disconnect() + bind_request_context = provide( + LDAPBindRequestContext, + scope=Scope.REQUEST, + ) + search_request_context = provide( + LDAPSearchRequestContext, + scope=Scope.REQUEST, + ) + unbind_request_context = provide( + LDAPUnbindRequestContext, + scope=Scope.REQUEST, + ) + + network_policy_validator = provide( + NetworkPolicyValidatorGateway, + provides=NetworkPolicyValidatorProtocol, + scope=Scope.REQUEST, + ) + network_policy_validator_use_case = provide( + NetworkPolicyValidatorUseCase, + scope=Scope.REQUEST, + ) + class GlobalLDAPServerProvider(Provider): """Provider with session scope.""" @@ -895,8 +1011,9 @@ def get_session_factory( @provide(scope=Scope.APP) async def get_conn_factory( self, - engine: AsyncEngine, + engine_registry: EngineRegistry, ) -> AsyncIterator[AsyncConnection]: """Create session factory.""" + engine = engine_registry.get_master_engine() async with engine.connect() as connection: yield connection diff --git a/app/ldap_protocol/auth/auth_manager.py b/app/ldap_protocol/auth/auth_manager.py index 61c336f56..6dff8b69e 100644 --- a/app/ldap_protocol/auth/auth_manager.py +++ b/app/ldap_protocol/auth/auth_manager.py @@ -14,9 +14,8 @@ from config import Settings from entities import User from enums import AuthorizationRules, MFAFlags -from ldap_protocol.auth.dto import SetupDTO +from ldap_protocol.auth.dto import LoginRequestDTO, LoginResponseDTO, SetupDTO from ldap_protocol.auth.mfa_manager import MFAManager -from ldap_protocol.auth.schemas import LoginDTO, OAuth2Form from ldap_protocol.auth.use_cases import SetupUseCase from ldap_protocol.auth.utils import authenticate_user from ldap_protocol.dialogue import UserSchema @@ -100,11 +99,11 @@ def __getattribute__(self, name: str) -> object: async def login( self, - form: OAuth2Form, + form: LoginRequestDTO, url: URL, ip: IPv4Address 
| IPv6Address, user_agent: str, - ) -> LoginDTO: + ) -> LoginResponseDTO: """Log in a user. :param form: OAuth2Form with username and password @@ -169,8 +168,8 @@ async def login( ) if request_2fa: ( - mfa_challenge, - key, + mfa_challenge_dto, + session_key, ) = await self._mfa_manager.two_factor_protocol( user=user, network_policy=network_policy, @@ -178,7 +177,10 @@ async def login( ip=ip, user_agent=user_agent, ) - return LoginDTO(key, mfa_challenge) + return LoginResponseDTO( + session_key=session_key, + mfa_challenge=mfa_challenge_dto, + ) session_key = await self._repository.create_session_key( user, @@ -186,7 +188,10 @@ async def login( user_agent, self.key_ttl, ) - return LoginDTO(session_key, None) + return LoginResponseDTO( + session_key=session_key, + mfa_challenge=None, + ) async def _update_password( self, @@ -232,7 +237,7 @@ async def _update_password( if include_krb: await self._kadmin.create_or_update_principal_pw( - user.get_upn_prefix(), + user.sam_account_name, new_password, ) diff --git a/app/ldap_protocol/auth/dto.py b/app/ldap_protocol/auth/dto.py index 909c0f35e..80ac483d0 100644 --- a/app/ldap_protocol/auth/dto.py +++ b/app/ldap_protocol/auth/dto.py @@ -6,6 +6,16 @@ from dataclasses import dataclass +from enums import MFAChallengeStatuses + + +@dataclass +class LoginRequestDTO: + """Login request DTO.""" + + username: str + password: str + @dataclass class SetupDTO: @@ -17,3 +27,40 @@ class SetupDTO: display_name: str mail: str password: str + + +@dataclass +class MFAChallengeResponseDTO: + """MFA challenge response DTO.""" + + status: MFAChallengeStatuses + message: str + + +@dataclass +class LoginResponseDTO: + """Login response DTO.""" + + session_key: str | None + mfa_challenge: MFAChallengeResponseDTO | None + + +@dataclass +class MFACreateRequestDTO: + """MFA create request DTO.""" + + mfa_key: str + mfa_secret: str + is_ldap_scope: bool + secret_name: str + key_name: str + + +@dataclass +class MFAGetResponseDTO: + """MFA get response 
DTO.""" + + mfa_key: str | None + mfa_secret: str | None + mfa_key_ldap: str | None + mfa_secret_ldap: str | None diff --git a/app/ldap_protocol/auth/mfa_manager.py b/app/ldap_protocol/auth/mfa_manager.py index 334a66a44..0e24f2285 100644 --- a/app/ldap_protocol/auth/mfa_manager.py +++ b/app/ldap_protocol/auth/mfa_manager.py @@ -21,6 +21,11 @@ from config import Settings from entities import CatalogueSetting, NetworkPolicy, User from enums import AuthorizationRules, MFAChallengeStatuses, MFAFlags +from ldap_protocol.auth.dto import ( + MFAChallengeResponseDTO, + MFACreateRequestDTO, + MFAGetResponseDTO, +) from ldap_protocol.auth.exceptions.mfa import ( AuthenticationError, ForbiddenError, @@ -31,11 +36,6 @@ MissingMFACredentialsError, NetworkPolicyError, ) -from ldap_protocol.auth.schemas import ( - MFAChallengeResponse, - MFACreateRequest, - MFAGetResponse, -) from ldap_protocol.auth.utils import get_user from ldap_protocol.identity import IdentityProvider from ldap_protocol.multifactor import ( @@ -102,10 +102,10 @@ def __getattribute__(self, name: str) -> object: return self._monitor.wrap_proxy_request(attr) return attr - async def setup_mfa(self, mfa: MFACreateRequest) -> bool: + async def setup_mfa(self, mfa: MFACreateRequestDTO) -> bool: """Create or update MFA keys. - :param mfa: MFACreateRequest + :param mfa: MFACreateRequestDTO :return: bool """ async with self._session.begin_nested(): @@ -151,12 +151,12 @@ async def get_mfa( self, mfa_creds: MFA_HTTP_Creds | None, mfa_creds_ldap: MFA_LDAP_Creds | None, - ) -> MFAGetResponse: + ) -> MFAGetResponseDTO: """Get MFA keys for http and ldap. 
:param mfa_creds: MFA_HTTP_Creds or None :param mfa_creds_ldap: MFA_LDAP_Creds or None - :return: MFAGetResponse + :return: MFAGetResponseDTO """ if not mfa_creds: mfa_creds = MFA_HTTP_Creds(Creds(None, None)) @@ -164,7 +164,7 @@ async def get_mfa( if not mfa_creds_ldap: mfa_creds_ldap = MFA_LDAP_Creds(Creds(None, None)) - return MFAGetResponse( + return MFAGetResponseDTO( mfa_key=mfa_creds.key, mfa_secret=mfa_creds.secret, mfa_key_ldap=mfa_creds_ldap.key, @@ -219,14 +219,14 @@ async def _create_bypass_data( message: str, ip: IPv4Address | IPv6Address, user_agent: str, - ) -> tuple[MFAChallengeResponse, str | None]: + ) -> tuple[MFAChallengeResponseDTO, str | None]: """Create session key and response. :param user: User :param message: str :param ip: IPv4Address | IPv6Address :param user_agent: str - :return: tuple[MFAChallengeResponse, str | None] + :return: tuple[MFAChallengeResponseDTO, str | None] """ key = await self._repository.create_session_key( user, @@ -235,7 +235,7 @@ async def _create_bypass_data( self.key_ttl, ) return ( - MFAChallengeResponse( + MFAChallengeResponseDTO( status=MFAChallengeStatuses.BYPASS, message=message, ), @@ -249,7 +249,7 @@ async def two_factor_protocol( url: URL, ip: IPv4Address | IPv6Address, user_agent: str, - ) -> tuple[MFAChallengeResponse, str | None]: + ) -> tuple[MFAChallengeResponseDTO, str | None]: """Initiate two-factor protocol with application. 
:param user: User @@ -258,7 +258,7 @@ async def two_factor_protocol( :param ip: IP address :param user_agent: User-Agent string :return: - tuple[MFAChallengeResponse, str | None] (session key | None) + tuple[MFAChallengeResponseDTO, str | None] (session key | None) :raises MissingMFACredentialsError: if MFA is not initialized :raises InvalidCredentialsError: if credentials are invalid :raises NetworkPolicyError: if network policy is not passed @@ -300,7 +300,7 @@ async def two_factor_protocol( weakref.finalize(bypass_coro, bypass_coro.close) return ( - MFAChallengeResponse( + MFAChallengeResponseDTO( status=MFAChallengeStatuses.PENDING, message=redirect_url, ), diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index b5bfe580a..6cbad0ea1 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -16,6 +16,7 @@ AttributeValueValidator, ) from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils @@ -59,7 +60,7 @@ async def setup_enviroment( self, *, data: list, - is_system: bool, + is_system: bool = True, dn: str = "multifactor.dev", ) -> None: """Create directories and users for enviroment.""" @@ -113,6 +114,7 @@ async def setup_enviroment( domain=domain, parent=domain, ) + base_directories_cache.clear() except Exception: import traceback @@ -132,13 +134,13 @@ async def create_dir( is_system=is_system, object_class=data["object_class"], name=data["name"], - parent=parent, ) dir_.groups = [] dir_.create_path(parent, dir_.get_dn_prefix()) self._session.add(dir_) await self._session.flush() + dir_.parent_id = parent.id if parent else None await self._session.refresh(dir_, ["id"]) self._session.add( diff --git 
a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index b9a53414e..ca063bcd7 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -9,11 +9,14 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from config import Settings from constants import ( DOMAIN_ADMIN_GROUP_NAME, + DOMAIN_CONTROLLERS_OU_NAME, FIRST_SETUP_DATA, USERS_CONTAINER_NAME, ) +from enums import SamAccountTypeCodes from ldap_protocol.auth.dto import SetupDTO from ldap_protocol.auth.setup_gateway import SetupGateway from ldap_protocol.identity.exceptions import ( @@ -21,6 +24,7 @@ ForbiddenError, ) from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.roles.role_use_case import RoleUseCase @@ -38,6 +42,7 @@ def __init__( role_use_case: RoleUseCase, audit_use_case: AuditUseCase, session: AsyncSession, + settings: Settings, ) -> None: """Initialize Setup manager. @@ -51,6 +56,7 @@ def __init__( self._role_use_case = role_use_case self._audit_use_case = audit_use_case self._session = session + self._settings = settings async def setup(self, dto: SetupDTO) -> None: """Perform the initial setup of structure and policies. 
@@ -66,6 +72,7 @@ async def setup(self, dto: SetupDTO) -> None: data = copy.deepcopy(FIRST_SETUP_DATA) data.append(self._create_user_data(dto)) + data.append(self._create_domain_controller_data()) await self._create(dto, data) @@ -76,6 +83,36 @@ async def is_setup(self) -> bool: """ return await self._setup_gateway.is_setup() + def _create_domain_controller_data(self) -> dict: + return { + "name": DOMAIN_CONTROLLERS_OU_NAME, + "object_class": "organizationalUnit", + "attributes": { + "objectClass": ["top", "container"], + }, + "children": [ + { + "name": self._settings.HOST_MACHINE_SHORT_NAME, + "object_class": "computer", + "attributes": { + "objectClass": ["top"], + "userAccountControl": [ + str( + UserAccountControlFlag.SERVER_TRUST_ACCOUNT.value, + ), + ], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_MACHINE_ACCOUNT), + ], + "sAMAccountName": [ + self._settings.HOST_MACHINE_SHORT_NAME, + ], + "ipHostNumber": [self._settings.DEFAULT_NAMESERVER], + }, + }, + ], + } + def _create_user_data(self, dto: SetupDTO) -> dict: """Create user data by request. @@ -114,6 +151,9 @@ def _create_user_data(self, dto: SetupDTO) -> dict: "userAccountControl": ["512"], "primaryGroupID": ["512"], "givenName": [dto.username], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_USER_OBJECT), + ], }, "objectSid": 500, }, diff --git a/app/ldap_protocol/custom_requests/__init__.py b/app/ldap_protocol/custom_requests/__init__.py new file mode 100644 index 000000000..2f7a89149 --- /dev/null +++ b/app/ldap_protocol/custom_requests/__init__.py @@ -0,0 +1,9 @@ +"""Custom Requests. 
+ +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from .rename import RenameRequest + +__all__ = ["RenameRequest"] diff --git a/app/ldap_protocol/custom_requests/rename.py b/app/ldap_protocol/custom_requests/rename.py new file mode 100644 index 000000000..1748112f8 --- /dev/null +++ b/app/ldap_protocol/custom_requests/rename.py @@ -0,0 +1,85 @@ +"""RenameRequest for main router. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dishka import AsyncContainer +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession + +from ldap_protocol.ldap_requests import ( + ModifyDNRequest as LDAPModifyDNRequest, + ModifyRequest as LDAPModifyRequest, +) +from ldap_protocol.ldap_responses import LDAPResult +from ldap_protocol.objects import Changes + + +class RenameRequest(BaseModel): + """Rename Request. It's not RFC 4511. + + Combines ModifyDN and Modify operations. 
+ """ + + object: str + newrdn: str + changes: list[Changes] + + @property + def _new_object(self) -> str: + return f"{self.newrdn},{','.join(self.object.split(',')[1:])}" + + @property + def _oldrdn(self) -> str: + return self.object.split(",")[0] + + async def _modify_dn_request( + self, + container: AsyncContainer, + entry: str, + newrdn: str, + ) -> LDAPResult: + modify_dn_request = LDAPModifyDNRequest( + entry=entry, + newrdn=newrdn, + deleteoldrdn=True, + new_superior=None, + ) + return await modify_dn_request.handle_api(container) + + async def _expire_session_objects(self, container: AsyncContainer) -> None: + session = await container.get(AsyncSession) + session.expire_all() + + async def _modify_request(self, container: AsyncContainer) -> LDAPResult: + modify_request = LDAPModifyRequest( + object=self._new_object, + changes=self.changes, + ) + return await modify_request.handle_api(container) + + async def handle_api(self, container: AsyncContainer) -> LDAPResult: + """Handle RenameRequest by executing ModifyDN then Modify. + + If ModifyRequest fails, rollback the ModifyDnRequest and return error. 
+ """ + modify_dn_response = await self._modify_dn_request( + container, + self.object, + self.newrdn, + ) + if not modify_dn_response or modify_dn_response.result_code != 0: + return modify_dn_response + + await self._expire_session_objects(container) + + modify_response = await self._modify_request(container) + if not modify_response or modify_response.result_code != 0: + await self._modify_dn_request( + container, + self._new_object, + self._oldrdn, + ) + + return modify_response diff --git a/app/ldap_protocol/dhcp/__init__.py b/app/ldap_protocol/dhcp/__init__.py index cf26f1903..a813b3d03 100644 --- a/app/ldap_protocol/dhcp/__init__.py +++ b/app/ldap_protocol/dhcp/__init__.py @@ -8,21 +8,10 @@ DHCPEntryNotFoundError, DHCPEntryUpdateError, DHCPOperationError, - DHCPValidatonError, + DHCPValidationError, ) from .kea_dhcp_manager import KeaDHCPManager from .kea_dhcp_repository import KeaDHCPAPIRepository -from .schemas import ( - DHCPChangeStateSchemaRequest, - DHCPLeaseSchemaRequest, - DHCPLeaseSchemaResponse, - DHCPLeaseToReservationErrorResponse, - DHCPReservationSchemaRequest, - DHCPReservationSchemaResponse, - DHCPStateSchemaResponse, - DHCPSubnetSchemaAddRequest, - DHCPSubnetSchemaResponse, -) from .stub import StubDHCPAPIRepository, StubDHCPManager @@ -54,17 +43,8 @@ def get_dhcp_api_repository_class( "DHCPEntryDeleteError", "DHCPEntryAddError", "DHCPEntryUpdateError", - "DHCPValidatonError", + "DHCPValidationError", "DHCPOperationError", "DHCPAPIError", "DHCPSubnetSchemaRequest", - "DHCPSubnetSchemaAddRequest", - "DHCPReservationSchemaRequest", - "DHCPSubnetSchemaResponse", - "DHCPLeaseSchemaRequest", - "DHCPLeaseSchemaResponse", - "DHCPReservationSchemaResponse", - "DHCPChangeStateSchemaRequest", - "DHCPStateSchemaResponse", - "DHCPLeaseToReservationErrorResponse", ] diff --git a/app/ldap_protocol/dhcp/dtos.py b/app/ldap_protocol/dhcp/dtos.py new file mode 100644 index 000000000..2b7c097a0 --- /dev/null +++ b/app/ldap_protocol/dhcp/dtos.py @@ -0,0 +1,49 
@@ +"""DTOs for DHCP manager. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dataclasses import dataclass, field + +from .dataclasses import DHCPLease, DHCPReservation, DHCPSubnet +from .enums import KeaDHCPCommands + + +@dataclass +class KeaDHCPCommandRequest: + """Single command request.""" + + command: KeaDHCPCommands + + +@dataclass +class KeaDHCPBaseAPIRequest(KeaDHCPCommandRequest): + """Base request for Kea DHCP API.""" + + arguments: list[int] | dict[str, str] | None = None + service: list[str] = field(default_factory=lambda: ["dhcp4"]) + + +@dataclass +class KeaDHCPAPISubnetRequest(KeaDHCPCommandRequest): + """Request for Kea DHCP API to manage subnets.""" + + subnet4: DHCPSubnet | list[DHCPSubnet] + service: list[str] = field(default_factory=lambda: ["dhcp4"]) + + +@dataclass +class KeaDHCPAPILeaseRequest(KeaDHCPCommandRequest): + """Request for Kea DHCP API to manage leases.""" + + lease: DHCPLease + service: list[str] = field(default_factory=lambda: ["dhcp4"]) + + +@dataclass +class KeaDHCPAPIReservationRequest(KeaDHCPCommandRequest): + """Request for Kea DHCP API to manage reservations.""" + + arguments: DHCPReservation + service: list[str] = field(default_factory=lambda: ["dhcp4"]) diff --git a/app/ldap_protocol/dhcp/exceptions.py b/app/ldap_protocol/dhcp/exceptions.py index 4b29a4514..c5be2f126 100644 --- a/app/ldap_protocol/dhcp/exceptions.py +++ b/app/ldap_protocol/dhcp/exceptions.py @@ -37,7 +37,7 @@ class DHCPAPIError(DHCPError): code = ErrorCodes.DHCP_API_ERROR -class DHCPValidatonError(DHCPError): +class DHCPValidationError(DHCPError): """DHCP validation error.""" code = ErrorCodes.DHCP_VALIDATION_ERROR diff --git a/app/ldap_protocol/dhcp/kea_dhcp_repository.py b/app/ldap_protocol/dhcp/kea_dhcp_repository.py index a41c8e82c..849523bec 100644 --- a/app/ldap_protocol/dhcp/kea_dhcp_repository.py +++ b/app/ldap_protocol/dhcp/kea_dhcp_repository.py @@ -17,6 +17,12 @@ 
DHCPReservation, DHCPSubnet, ) +from .dtos import ( + KeaDHCPAPILeaseRequest, + KeaDHCPAPIReservationRequest, + KeaDHCPAPISubnetRequest, + KeaDHCPBaseAPIRequest, +) from .enums import KeaDHCPCommands, KeaDHCPResultCodes from .exceptions import ( DHCPAPIError, @@ -40,12 +46,6 @@ release_lease_retort, update_subnet_retort, ) -from .schemas import ( - KeaDHCPAPILeaseRequest, - KeaDHCPAPIReservationRequest, - KeaDHCPAPISubnetRequest, - KeaDHCPBaseAPIRequest, -) class KeaDHCPAPIRepository(DHCPAPIRepository): diff --git a/app/ldap_protocol/dhcp/retorts.py b/app/ldap_protocol/dhcp/retorts.py index 2a0d4c11d..023b800a1 100644 --- a/app/ldap_protocol/dhcp/retorts.py +++ b/app/ldap_protocol/dhcp/retorts.py @@ -7,7 +7,7 @@ from adaptix import Retort, name_mapping from .dataclasses import DHCPLease, DHCPReservation, DHCPSubnet -from .schemas import ( +from .dtos import ( KeaDHCPAPILeaseRequest, KeaDHCPAPISubnetRequest, KeaDHCPBaseAPIRequest, diff --git a/app/ldap_protocol/dns/__init__.py b/app/ldap_protocol/dns/__init__.py index f9c97fba7..f647092a3 100644 --- a/app/ldap_protocol/dns/__init__.py +++ b/app/ldap_protocol/dns/__init__.py @@ -1,61 +1,64 @@ -from .base import ( +from ldap_protocol.dns.clients import ( + PowerDNSAuthHTTPClient, + PowerDNSDistClient, + PowerDNSRecursorHTTPClient, +) +from ldap_protocol.dns.constants import ( DNS_MANAGER_IP_ADDRESS_NAME, DNS_MANAGER_STATE_NAME, DNS_MANAGER_ZONE_NAME, - AbstractDNSManager, +) +from ldap_protocol.dns.dns_gateway import DNSStateGateway +from ldap_protocol.dns.dto import ( DNSForwardServerStatus, - DNSForwardZone, - DNSManagerSettings, + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRRSetDTO, + DNSSettingsDTO, + PowerDNSSettingsDTO, +) +from ldap_protocol.dns.enums import ( DNSManagerState, + DNSRecordType, + PowerDNSZoneType, +) +from ldap_protocol.dns.exceptions import ( + DNSConnectionError, + DNSError, DNSNotImplementedError, - DNSRecords, - DNSServerParam, - DNSServerParamName, - DNSZone, - DNSZoneParam, - 
DNSZoneParamName, - DNSZoneType, ) -from .dns_gateway import DNSStateGateway -from .exceptions import DNSConnectionError, DNSError -from .remote import RemoteDNSManager -from .selfhosted import SelfHostedDNSManager -from .stub import StubDNSManager - - -async def get_dns_manager_class( - dns_state_gateway: DNSStateGateway, -) -> type[AbstractDNSManager]: - """Get DNS manager class.""" - dns_state = await dns_state_gateway.get_dns_state() - if dns_state == DNSManagerState.SELFHOSTED: - return SelfHostedDNSManager - elif dns_state == DNSManagerState.HOSTED: - return RemoteDNSManager - return StubDNSManager - +from ldap_protocol.dns.managers import ( + AbstractDNSManager, + PowerDNSManager, + RemoteDNSManager, + StubDNSManager, +) +from ldap_protocol.dns.use_cases import DNSUseCase __all__ = [ "get_dns_manager_class", + "DNSUseCase", "AbstractDNSManager", + "PowerDNSManager", + "PowerDNSAuthHTTPClient", + "PowerDNSRecursorHTTPClient", + "PowerDNSDistClient", "RemoteDNSManager", - "SelfHostedDNSManager", "StubDNSManager", "DNSStateGateway", "DNSForwardServerStatus", - "DNSForwardZone", - "DNSManagerSettings", - "DNSRecords", - "DNSServerParam", - "DNSZone", - "DNSZoneParam", - "DNSZoneType", - "DNSServerParamName", - "DNSZoneParamName", - "DNSConnectionError", + "DNSForwardZoneDTO", + "DNSSettingsDTO", + "PowerDNSSettingsDTO", + "DNSRRSetDTO", + "DNSMasterZoneDTO", + "PowerDNSZoneType", + "DNSRecordType", + "DNSManagerState", "DNS_MANAGER_IP_ADDRESS_NAME", "DNS_MANAGER_ZONE_NAME", "DNS_MANAGER_STATE_NAME", "DNSNotImplementedError", "DNSError", + "DNSConnectionError", ] diff --git a/app/ldap_protocol/dns/base.py b/app/ldap_protocol/dns/base.py deleted file mode 100644 index 01fe71c8c..000000000 --- a/app/ldap_protocol/dns/base.py +++ /dev/null @@ -1,306 +0,0 @@ -"""Abstract DNS service for DNS server managing. 
- -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from abc import abstractmethod -from dataclasses import dataclass -from enum import StrEnum -from ipaddress import IPv4Address, IPv6Address - -import httpx -from loguru import logger as loguru_logger - -from ldap_protocol.dns.dto import DNSSettingDTO - -from .exceptions import DNSSetupError - -DNS_MANAGER_STATE_NAME = "DNSManagerState" -DNS_MANAGER_ZONE_NAME = "DNSManagerZoneName" -DNS_MANAGER_IP_ADDRESS_NAME = "DNSManagerIpAddress" -DNS_MANAGER_TSIG_KEY_NAME = "DNSManagerTSIGKey" -log = loguru_logger.bind(name="DNSManager") - -log.add( - "logs/dnsmanager_{time:DD-MM-YYYY}.log", - filter=lambda rec: rec["extra"].get("name") == "dnsmanager", - retention="10 days", - rotation="1d", - colorize=False, -) - - -class DNSZoneType(StrEnum): - """DNS zone types.""" - - MASTER = "master" - FORWARD = "forward" - - -class DNSForwarderServerStatus(StrEnum): - """Forwarder DNS server statuses.""" - - VALIDATED = "validated" - NOT_VALIDATED = "not validated" - NOT_FOUND = "not found" - - -class DNSNotImplementedError(NotImplementedError): - """API Not Implemented Error.""" - - -class DNSRecordType(StrEnum): - """DNS record types.""" - - a = "A" - aaaa = "AAAA" - cname = "CNAME" - mx = "MX" - ns = "NS" - txt = "TXT" - soa = "SOA" - ptr = "PTR" - srv = "SRV" - - -class DNSZoneParamName(StrEnum): - """Possible DNS zone option names.""" - - acl = "acl" - forwarders = "forwarders" - ttl = "ttl" - - -class DNSServerParamName(StrEnum): - """Possible DNS server option names.""" - - dnssec = "dnssec-validation" - - -class DNSManagerState(StrEnum): - """DNSManager state enum.""" - - NOT_CONFIGURED = "0" - SELFHOSTED = "1" - HOSTED = "2" - - -@dataclass -class DNSZoneParam: - """DNS zone parameter.""" - - name: DNSZoneParamName - value: str | list[str] | None - - -@dataclass -class DNSServerParam: - """DNS zone parameter.""" - - name: DNSServerParamName - value: str | 
list[str] - - -@dataclass -class DNSForwardServerStatus: - """Forward DNS server status.""" - - ip: str - status: DNSForwarderServerStatus - FQDN: str | None - - -@dataclass -class DNSRecord: - """Single dns record.""" - - name: str - value: str - ttl: int - - -@dataclass -class DNSRecords: - """Grouped dns records.""" - - type: str - records: list[DNSRecord] - - -@dataclass -class DNSZone: - """DNS zone.""" - - name: str - type: DNSZoneType - records: list[DNSRecords] - - -@dataclass -class DNSForwardZone: - """DNS forward zone.""" - - name: str - type: DNSZoneType - forwarders: list[str] - - -class DNSManagerSettings: - """DNS Manager settings.""" - - zone_name: str | None - domain: str | None - dns_server_ip: str | None - tsig_key: str | None - - def __init__( - self, - zone_name: str | None, - dns_server_ip: str | None, - tsig_key: str | None, - ) -> None: - """Set settings.""" - self.zone_name = zone_name - self.domain = zone_name + "." if zone_name is not None else None - self.dns_server_ip = dns_server_ip - self.tsig_key = tsig_key - - -class AbstractDNSManager: - """Abstract DNS manager class.""" - - _dns_settings: DNSManagerSettings - _http_client: httpx.AsyncClient - - def __init__( - self, - settings: DNSManagerSettings, - http_client: httpx.AsyncClient, - ) -> None: - """Set up DNS manager.""" - self._dns_settings = settings - self._http_client = http_client - - async def setup( - self, - dns_status: str, - domain: str, - dns_ip_address: str | IPv4Address | IPv6Address | None, - tsig_key: str | None, - ) -> DNSSettingDTO: - """Set up DNS server and DNS manager.""" - try: - if ( - dns_status == DNSManagerState.SELFHOSTED - and self._http_client is not None - ): - await self._http_client.post( - "/server/setup", - json={"zone_name": domain}, - ) - tsig_key = None - return DNSSettingDTO( - zone_name=domain, - dns_server_ip=dns_ip_address, - tsig_key=tsig_key, - ) - - except Exception as e: - raise DNSSetupError(e) - - @abstractmethod - async def 
create_record( - self, - hostname: str, - ip: str, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: ... - - @abstractmethod - async def update_record( - self, - hostname: str, - ip: str | None, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: ... - - @abstractmethod - async def delete_record( - self, - hostname: str, - ip: str, - record_type: str, - zone_name: str | None = None, - ) -> None: ... - - @abstractmethod - async def get_all_records(self) -> list[DNSRecords]: ... - - @abstractmethod - async def get_all_zones_records(self) -> list[DNSZone]: - raise DNSNotImplementedError - - @abstractmethod - async def get_forward_zones(self) -> list[DNSForwardZone]: - raise DNSNotImplementedError - - @abstractmethod - async def create_zone( - self, - zone_name: str, - zone_type: DNSZoneType, - nameserver: str | None, - params: list[DNSZoneParam], - ) -> None: - raise DNSNotImplementedError - - @abstractmethod - async def update_zone( - self, - zone_name: str, - params: list[DNSZoneParam] | None, - ) -> None: - raise DNSNotImplementedError - - @abstractmethod - async def delete_zone( - self, - zone_names: list[str], - ) -> None: - raise DNSNotImplementedError - - @abstractmethod - async def check_forward_dns_server( - self, - dns_server_ip: IPv4Address | IPv6Address, - host_dns_servers: list[str], - ) -> DNSForwardServerStatus: - raise DNSNotImplementedError - - @abstractmethod - async def update_server_options( - self, - params: list[DNSServerParam], - ) -> None: - raise DNSNotImplementedError - - @abstractmethod - async def get_server_options(self) -> list[DNSServerParam]: ... 
- - @abstractmethod - async def restart_server( - self, - ) -> None: - raise DNSNotImplementedError - - @abstractmethod - async def reload_zone( - self, - zone_name: str, - ) -> None: - raise DNSNotImplementedError diff --git a/app/ldap_protocol/dns/bind_to_pdns_migration_use_case.py b/app/ldap_protocol/dns/bind_to_pdns_migration_use_case.py new file mode 100644 index 000000000..be208d12e --- /dev/null +++ b/app/ldap_protocol/dns/bind_to_pdns_migration_use_case.py @@ -0,0 +1,200 @@ +"""Manager for migrating from BIND to PowerDNS. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import os + +import dns.zone +from loguru import logger + +from ldap_protocol.dns.dto import ( + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRecordDTO, + DNSRRSetDTO, + DNSSettingsDTO, +) +from ldap_protocol.dns.enums import DNSRecordType +from ldap_protocol.dns.managers.power_dns_manager import PowerDNSManager + + +class BindToPDNSMigrationUseCase: + bind_zone_file_dir: str = "/opt/" + bind_config_files_dir: str = "/etc/bind/" + + def __init__( + self, + pdns_manager: PowerDNSManager, + dns_settings: DNSSettingsDTO, + ) -> None: + self.pdns_manager = pdns_manager + self.dns_settings = dns_settings + + def _strip_record_name(self, record_name: str, zone_name: str) -> str: + """Strip trash from record name.""" + logger.debug( + f"Stripping record name '{record_name}' for zone '{zone_name}'", + ) + if record_name.startswith(("\\032", "\\@")) and record_name != "\\@": + record_name = record_name.removeprefix("\\032").removeprefix("\\@") + elif record_name == "\\@": + record_name = zone_name + return ( + record_name if not record_name.startswith(".") else record_name[1:] + ) + + def parse_bind_config_file( + self, + ) -> tuple[list[DNSMasterZoneDTO], list[DNSForwardZoneDTO]]: + """Parse BIND configuration files to extract zone information.""" + master_zones: list[DNSMasterZoneDTO] = [] + forward_zones: 
list[DNSForwardZoneDTO] = [] + + with open( + os.path.join(self.bind_config_files_dir, "named.conf.local"), + ) as f: + for line in f: + line = line.strip() + if line.startswith("zone"): + parts = line.split() + if len(parts) >= 2: + zone_name = parts[1].strip('"') + continue + + if "type master" in line: + master_zones.append( + DNSMasterZoneDTO( + id=zone_name, + name=zone_name, + ), + ) + elif "type forward" in line: + forward_zone = DNSForwardZoneDTO( + id=zone_name, + name=zone_name, + ) + elif "forwarders" in line and forward_zone: + forwarders_part = line.split("forwarders")[1] + forwarders = [ + f + for f in forwarders_part.strip(";") + .strip(" ") + .strip("{") + .strip("}") + .strip(" ") + .split(";")[:-1] + ] + forward_zone.servers = forwarders + forward_zones.append(forward_zone) + forward_zone = None + + return master_zones, forward_zones + + def parse_zones_records( + self, + master_zones: list[DNSMasterZoneDTO], + ) -> list[DNSMasterZoneDTO]: + """Parse zone files to extract DNS records.""" + zones_with_records: list[DNSMasterZoneDTO] = [] + + for zone in master_zones: + zone_rrsets: list[DNSRRSetDTO] = [] + zone_file_path = os.path.join( + self.bind_zone_file_dir, + f"{zone.name}.zone", + ) + try: + zone_obj = dns.zone.from_file( + zone_file_path, + origin=zone.name, + relativize=False, + ) + except FileNotFoundError: + logger.error( + f"Zone file for zone {zone.name} not found, skipping...", + ) + continue + + for name, ttl, rdata in zone_obj.iterate_rdatas(): + try: + record_type = DNSRecordType(rdata.rdtype.name) + except ValueError: + logger.warning( + f"Unsupported DNS record type {rdata.rdtype.name} in zone '{zone.name}'", # noqa: E501 + ) + continue + + zone_rrsets.append( + DNSRRSetDTO( + name=self._strip_record_name( + name.to_text(), + zone.name, + ), + type=record_type, + records=[ + DNSRecordDTO( + content=rdata.to_text(), + disabled=False, + ), + ], + ttl=ttl, + ), + ) + zone.rrsets = zone_rrsets + zones_with_records.append(zone) + + 
return zones_with_records + + async def get_bind_zones( + self, + ) -> tuple[list[DNSMasterZoneDTO], list[DNSForwardZoneDTO]]: + """Get zones from BIND.""" + master_zones, forward_zones = self.parse_bind_config_file() + master_zones = self.parse_zones_records(master_zones) + + return master_zones, forward_zones + + async def migrate_from_bind(self) -> None: + """Migrate from BIND to PowerDNS.""" + master_zones, forward_zones = await self.get_bind_zones() + + for master_zone in master_zones: + await self.pdns_manager.create_master_zone( + master_zone, + is_empty=True, + ) + for rrset in master_zone.rrsets: + await self.pdns_manager.create_record( + master_zone.name, + rrset, + ) + + for forward_zone in forward_zones: + await self.pdns_manager.create_forward_zone(forward_zone) + + open(os.path.join(self.bind_zone_file_dir, "migrated"), "a").close() + open(os.path.join(self.bind_config_files_dir, "migrated"), "a").close() + + def is_migration_needed(self) -> bool: + """Check if migration is needed.""" + return not ( + os.path.exists(os.path.join(self.bind_zone_file_dir, "migrated")) + and os.path.exists( + os.path.join(self.bind_config_files_dir, "migrated"), + ) + ) and bool(os.listdir(self.bind_zone_file_dir)) + + async def migrate(self) -> None: + """Migrate from BIND to PowerDNS.""" + if not self.is_migration_needed(): + logger.info("BIND to PowerDNS migration is not needed, exiting...") + return + + logger.info("Starting BIND to PowerDNS migration...") + await self.pdns_manager.setup(self.dns_settings, is_migration=True) + + await self.migrate_from_bind() + logger.info("Migration successful") + return diff --git a/app/ldap_protocol/dns/clients/__init__.py b/app/ldap_protocol/dns/clients/__init__.py new file mode 100644 index 000000000..64a731025 --- /dev/null +++ b/app/ldap_protocol/dns/clients/__init__.py @@ -0,0 +1,17 @@ +from ldap_protocol.dns.clients.abstract_client import ( + AbstractDNSForwardHTTPClient, + AbstractDNSMasterHTTPClient, +) +from 
ldap_protocol.dns.clients.power_dns_http_clients import ( + PowerDNSAuthHTTPClient, + PowerDNSRecursorHTTPClient, +) +from ldap_protocol.dns.clients.power_dnsdist_client import PowerDNSDistClient + +__all__ = [ + "PowerDNSDistClient", + "PowerDNSAuthHTTPClient", + "PowerDNSRecursorHTTPClient", + "AbstractDNSMasterHTTPClient", + "AbstractDNSForwardHTTPClient", +] diff --git a/app/ldap_protocol/dns/clients/abstract_client.py b/app/ldap_protocol/dns/clients/abstract_client.py new file mode 100644 index 000000000..fa05a95c4 --- /dev/null +++ b/app/ldap_protocol/dns/clients/abstract_client.py @@ -0,0 +1,110 @@ +"""Abstract DNS client for DNS server managing. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from abc import abstractmethod + +import httpx +from fastapi import status + +from ldap_protocol.dns.dto import ( + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRRSetDTO, +) +from ldap_protocol.dns.exceptions import ( + DNSEntryNotFoundError, + DNSNotImplementedError, + DNSNotSupportedError, + DNSUnavailableError, + DNSValidationError, +) + + +class AbstractDNSHTTPClient: + """Abstract DNS client class.""" + + def __init__( + self, + http_client: httpx.AsyncClient, + ) -> None: + """Initialize the PowerDNS HTTP client.""" + self._http_client = http_client + + async def _validate_response(self, response: httpx.Response) -> None: + """Validate the API response.""" + match response.status_code: + case status.HTTP_400_BAD_REQUEST: + raise DNSNotSupportedError(response.text or "Bad Request") + case status.HTTP_404_NOT_FOUND: + raise DNSEntryNotFoundError(response.text or "Not Found") + case status.HTTP_422_UNPROCESSABLE_ENTITY: + raise DNSValidationError( + response.text or "Unprocessable Entity", + ) + case status.HTTP_500_INTERNAL_SERVER_ERROR: + raise DNSUnavailableError( + response.text or "Internal Server Error", + ) + + +class AbstractDNSMasterHTTPClient(AbstractDNSHTTPClient): + """Abstract DNS 
client for master server.""" + + @abstractmethod + async def create_record(self, record: DNSRRSetDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def get_records(self, zone_id: str) -> list[DNSRRSetDTO]: + raise DNSNotImplementedError + + @abstractmethod + async def update_record(self, record: DNSRRSetDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def delete_record(self, zone_id: str, record: DNSRRSetDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def create_master_zone(self, zone: DNSMasterZoneDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: + raise DNSNotImplementedError + + @abstractmethod + async def get_master_zone_by_id(self, zone_id: str) -> DNSMasterZoneDTO: + raise DNSNotImplementedError + + @abstractmethod + async def update_master_zone(self, zone: DNSMasterZoneDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def delete_master_zone(self, zone_id: str) -> None: + raise DNSNotImplementedError + + +class AbstractDNSForwardHTTPClient(AbstractDNSHTTPClient): + """Abstract DNS slient for forward server.""" + + @abstractmethod + async def create_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: + raise DNSNotImplementedError + + @abstractmethod + async def update_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + raise DNSNotImplementedError + + @abstractmethod + async def delete_forward_zone(self, zone_id: str) -> None: + raise DNSNotImplementedError diff --git a/app/ldap_protocol/dns/clients/power_dns_http_clients.py b/app/ldap_protocol/dns/clients/power_dns_http_clients.py new file mode 100644 index 000000000..82fc345b2 --- /dev/null +++ b/app/ldap_protocol/dns/clients/power_dns_http_clients.py @@ -0,0 +1,112 @@ +"""HTTP Client for PowerDNS servers. 
+ +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from adaptix import Retort + +from ldap_protocol.dns.clients.abstract_client import AbstractDNSHTTPClient +from ldap_protocol.dns.dto import ( + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRRSetDTO, +) + +base_retort = Retort() + + +class PowerDNSAuthHTTPClient(AbstractDNSHTTPClient): + """HTTP client for PowerDNS Auth server.""" + + async def record_action(self, zone_id: str, record: DNSRRSetDTO) -> None: + """Send request to perform action on DNS record in given zone.""" + response = await self._http_client.patch( + f"/zones/{zone_id}", + json={"rrsets": [base_retort.dump(record)]}, + ) + + await self._validate_response(response) + + async def get_records(self, zone_id: str) -> list[DNSRRSetDTO]: + """Send request to get all records of given zone.""" + response = await self._http_client.get(f"/zones/{zone_id}") + await self._validate_response(response) + + zone = base_retort.load(response.json(), DNSMasterZoneDTO) + return zone.rrsets + + async def create_master_zone(self, zone: DNSMasterZoneDTO) -> None: + """Send request to create new master zone.""" + response = await self._http_client.post( + "/zones", + json=base_retort.dump(zone), + ) + await self._validate_response(response) + + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: + """Send request to get all master zones.""" + response = await self._http_client.get("/zones") + await self._validate_response(response) + + return base_retort.load(response.json(), list[DNSMasterZoneDTO]) + + async def get_master_zone_by_id(self, zone_id: str) -> DNSMasterZoneDTO: + """Send request to get master zone by ID.""" + response = await self._http_client.get(f"/zones/{zone_id}") + await self._validate_response(response) + + return base_retort.load(response.json(), DNSMasterZoneDTO) + + async def update_master_zone( + self, + zone_id: str, + zone: DNSMasterZoneDTO, + ) -> None: + """Send 
request to update master zone with given ID.""" + response = await self._http_client.put( + f"/zones/{zone_id}", + json=base_retort.dump(zone), + ) + await self._validate_response(response) + + async def delete_master_zone(self, zone_id: str) -> None: + """Send request to delete master zone with given ID.""" + response = await self._http_client.delete(f"/zones/{zone_id}") + await self._validate_response(response) + + +class PowerDNSRecursorHTTPClient(AbstractDNSHTTPClient): + """HTTP client for PowerDNS Recursor server.""" + + async def create_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + """Send request to create forward zone.""" + response = await self._http_client.post( + "/zones", + json=base_retort.dump(zone), + ) + await self._validate_response(response) + + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: + """Send request to get all forward zones.""" + response = await self._http_client.get("/zones") + await self._validate_response(response) + + return base_retort.load(response.json(), list[DNSForwardZoneDTO]) + + async def update_forward_zone( + self, + zone_id: str, + zone: DNSForwardZoneDTO, + ) -> None: + """Send request to update forward zone with given ID.""" + response = await self._http_client.put( + f"/zones/{zone_id}", + json=base_retort.dump(zone), + ) + await self._validate_response(response) + + async def delete_forward_zone(self, zone_id: str) -> None: + """Send request to delete forward zone with given ID.""" + response = await self._http_client.delete(f"/zones/{zone_id}") + await self._validate_response(response) diff --git a/app/ldap_protocol/dns/clients/power_dnsdist_client.py b/app/ldap_protocol/dns/clients/power_dnsdist_client.py new file mode 100644 index 000000000..1e13899f4 --- /dev/null +++ b/app/ldap_protocol/dns/clients/power_dnsdist_client.py @@ -0,0 +1,250 @@ +"""Clinet for Power dnsdist service. 
+ +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import re +from ipaddress import IPv4Address +from typing import Literal, overload + +from dnsdist_console import Console + +from ldap_protocol.dns.dto import ( + CommandResponse, + DNSdistCommand, + DNSdistCommandsDelta, + DNSdistRulesTable, + RuleEntry, +) +from ldap_protocol.dns.enums import DNSdistCommandTypes +from ldap_protocol.dns.exceptions import DNSdistError + + +class PowerDNSDistClient: + """Client for dnsdist.""" + + def __init__( + self, + dnsdist_host: str, + dnsdist_port: int, + dnsdist_key: str, + config_path: str, + ) -> None: + self._console = Console( + host=dnsdist_host, + port=dnsdist_port, + key=dnsdist_key, + ) + self._config_path = config_path + + @overload + def _send_command( + self, + command: str, + *, + expected: Literal[DNSdistCommandTypes.GENERIC], + ) -> CommandResponse: ... + + @overload + def _send_command( + self, + command: str, + *, + expected: Literal[DNSdistCommandTypes.SHOW_RULES], + ) -> DNSdistRulesTable: ... + + @overload + def _send_command( + self, + command: str, + *, + expected: Literal[DNSdistCommandTypes.COMMANDS_DELTA], + ) -> DNSdistCommandsDelta: ... 
+ + def _send_command( + self, + command: str, + *, + expected: DNSdistCommandTypes = DNSdistCommandTypes.GENERIC, + ) -> CommandResponse | DNSdistRulesTable | DNSdistCommandsDelta: + """Send command to dnsdist console.""" + raw: str = self._console.send_command(command) + + if expected is DNSdistCommandTypes.GENERIC: + if "error" in raw.lower() or "fail" in raw.lower(): + raise DNSdistError(f"dnsdist command error: {raw.strip()}") + return CommandResponse(message=raw.strip() or "OK") + + if expected is DNSdistCommandTypes.SHOW_RULES: + rules = [] + pattern = re.compile(r"^(\d+)\s+\d+\s+(.+?)\s{2,}(to .+)$") + for line in raw.strip().split("\n"): + if "to pool" in line and (matches := pattern.match(line)): + rules.append( + RuleEntry( + id=int(matches.group(1)), + match=matches.group(2).strip(), + action=matches.group(3).strip(), + ), + ) + return DNSdistRulesTable(rules=rules, count=len(rules)) + + if expected is DNSdistCommandTypes.COMMANDS_DELTA: + commands = [] + for command in raw.split("\n"): + commands.append(DNSdistCommand(command=command)) + return DNSdistCommandsDelta(delta=commands, count=len(commands)) + + def _get_all_rules(self) -> DNSdistRulesTable: + """Get list of all rules.""" + command = "showRules()" + return self._send_command( + command, + expected=DNSdistCommandTypes.SHOW_RULES, + ) + + def get_rule_by_match(self, match: str) -> RuleEntry | None: + """Get rule by rule match.""" + rules = self._get_all_rules() + for rule in rules.rules: + if rule.match == match: + return rule + + return None + + def add_server( + self, + server_host: str | IPv4Address, + pool: str, + ) -> None: + """Add server to dnsdist config.""" + command = f""" + newServer({{ + address = "{server_host}:53", + pool = "{pool}" + }}) + """ + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + self._persist_config() + + def setup_dnsdist(self, recursor_ip: str) -> None: + """Set up dnsdist with initial configuration.""" + command = f""" + 
newServer({{ + address = "{recursor_ip}:53", + pool = "recursor" + }}) + """ + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + command = """ + addAction( + AllRule(), + PoolAction("recursor") + ) + """ + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + def add_zone_rule(self, domain: str) -> None: + """Add rule to redirect master zone DNS requests to auth server.""" + command = f""" + addAction( + QNameSuffixRule("{domain}"), + PoolAction("master") + ) + """ + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + self._deprioritize_all_match_rule() + + self._persist_config() + + def remove_zone_rule(self, domain: str) -> None: + """Remove redirect rule from dnsdist.""" + rules = self._get_all_rules() + if not rules.count: + raise DNSdistError( + "Failed to delete existing rule in dnsdist: Not Found", + ) + + for rule in rules.rules: + rule_match = rule.match.split(" ")[-1] + domain_match = domain if domain.endswith(".") else f"{domain}." 
+ if domain_match == rule_match: + command = f"rmRule({rule.id})" + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + self._persist_config() + + def _deprioritize_all_match_rule(self) -> None: + """Remove and add all matching rule to depriortitize it.""" + rule = self.get_rule_by_match("All") + if rule is not None: + command = f"rmRule({rule.id})" + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + command = """ + addAction( + AllRule(), + PoolAction("recursor") + ) + """ + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + self._persist_config() + + def _get_commands_delta(self) -> DNSdistCommandsDelta: + """Get list of commands that have not been persisted yet.""" + command = "delta()" + return self._send_command( + command, + expected=DNSdistCommandTypes.COMMANDS_DELTA, + ) + + def _save_commands_delta( + self, + commands_delta: DNSdistCommandsDelta, + ) -> None: + """Save commands delta to dnsdist config file.""" + with open(self._config_path, "a+", encoding="utf-8") as config_file: + for command in commands_delta.delta: + config_file.write(f"{command.command}\n") + + def _clear_console_history(self) -> None: + """Clear console history to delete written delta.""" + command = "clearConsoleHistory()" + self._send_command( + command, + expected=DNSdistCommandTypes.GENERIC, + ) + + def _persist_config(self) -> None: + """Persist dnsdist config to file.""" + commands_delta = self._get_commands_delta() + if commands_delta.count: + self._save_commands_delta(commands_delta) + + self._clear_console_history() diff --git a/app/ldap_protocol/dns/constants.py b/app/ldap_protocol/dns/constants.py new file mode 100644 index 000000000..3e7a33171 --- /dev/null +++ b/app/ldap_protocol/dns/constants.py @@ -0,0 +1,98 @@ +"""Constants for DNS module. 
+ +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ldap_protocol.dns.enums import DNSRecordType + +DNS_MANAGER_STATE_NAME = "DNSManagerState" +DNS_MANAGER_ZONE_NAME = "DNSManagerZoneName" +DNS_MANAGER_IP_ADDRESS_NAME = "DNSManagerIpAddress" +DNS_MANAGER_TSIG_KEY_NAME = "DNSManagerTSIGKey" + +DEFAULT_FORWARD_ZONE_NAMES: list[str] = [ + ".", + "b.e.f.ip6.arpa.", + "a.e.f.ip6.arpa.", + "23.172.in-addr.arpa.", + "21.172.in-addr.arpa.", + "254.169.in-addr.arpa.", + "20.172.in-addr.arpa.", + "17.172.in-addr.arpa.", + "31.172.in-addr.arpa.", + "22.172.in-addr.arpa.", + "16.172.in-addr.arpa.", + "19.172.in-addr.arpa.", + "24.172.in-addr.arpa.", + "168.192.in-addr.arpa.", + "10.in-addr.arpa.", + "8.e.f.ip6.arpa.", + "127.in-addr.arpa.", + "113.0.203.in-addr.arpa.", + "26.172.in-addr.arpa.", + "27.172.in-addr.arpa.", + "8.b.d.0.1.0.0.2.ip6.arpa.", + "28.172.in-addr.arpa.", + "d.f.ip6.arpa.", + "18.172.in-addr.arpa.", + "30.172.in-addr.arpa.", + "9.e.f.ip6.arpa.", + "100.51.198.in-addr.arpa.", + "255.255.255.255.in-addr.arpa.", + "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", + "29.172.in-addr.arpa.", + "0.in-addr.arpa.", + "25.172.in-addr.arpa.", + "2.0.192.in-addr.arpa.", +] + +DNS_FIRST_SETUP_RECORDS: list[dict[str, str | DNSRecordType]] = [ + {"name": "_ldap._tcp.", "value": "0 0 389 ", "type": DNSRecordType.SRV}, + {"name": "_ldaps._tcp.", "value": "0 0 636 ", "type": DNSRecordType.SRV}, + {"name": "_kerberos._tcp.", "value": "0 0 88 ", "type": DNSRecordType.SRV}, + {"name": "_kerberos._udp.", "value": "0 0 88 ", "type": DNSRecordType.SRV}, + {"name": "_kdc._tcp.", "value": "0 0 88 ", "type": DNSRecordType.SRV}, + {"name": "_kdc._udp.", "value": "0 0 88 ", "type": DNSRecordType.SRV}, + {"name": "_kpasswd._tcp.", "value": "0 0 464 ", "type": DNSRecordType.SRV}, + {"name": "_kpasswd._udp.", "value": "0 0 464 ", "type": DNSRecordType.SRV}, + # Record for PDC Emulator + 
{ + "name": "_ldap._tcp.pdc._msdcs.", + "value": "0 100 389 ", + "type": DNSRecordType.SRV, + }, + # Records for DC Locator (for trusts) + { + "name": "_kerberos._tcp.dc._msdcs.", + "value": "0 100 88 ", + "type": DNSRecordType.SRV, + }, + { + "name": "_kerberos._tcp.Default-First-Site-Name._sites.dc._msdcs.", + "value": "0 100 88 ", + "type": DNSRecordType.SRV, + }, + { + "name": "_ldap._tcp.dc._msdcs.", + "value": "0 100 389 ", + "type": DNSRecordType.SRV, + }, + { + "name": "_ldap._tcp.Default-First-Site-Name._sites.dc._msdcs.", + "value": "0 100 389 ", + "type": DNSRecordType.SRV, + }, + # Records for Global Catalog + {"name": "_gc._tcp.", "value": "0 100 3268 ", "type": DNSRecordType.SRV}, + { + "name": "_ldap._tcp.Default-First-Site-Name._sites.gc._msdcs.", + "value": "0 100 3268 ", + "type": DNSRecordType.SRV, + }, + { + "name": "_ldap._tcp.gc._msdcs.", + "value": "0 100 3268 ", + "type": DNSRecordType.SRV, + }, +] diff --git a/app/ldap_protocol/dns/dns_gateway.py b/app/ldap_protocol/dns/dns_gateway.py index c5a525f35..f5fa6802d 100644 --- a/app/ldap_protocol/dns/dns_gateway.py +++ b/app/ldap_protocol/dns/dns_gateway.py @@ -4,21 +4,21 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from typing import Awaitable +from ipaddress import IPv4Address from sqlalchemy import case, select, update from sqlalchemy.ext.asyncio import AsyncSession +from config import Settings from entities import CatalogueSetting -from ldap_protocol.dns.base import ( +from ldap_protocol.dns.constants import ( DNS_MANAGER_IP_ADDRESS_NAME, DNS_MANAGER_STATE_NAME, DNS_MANAGER_TSIG_KEY_NAME, DNS_MANAGER_ZONE_NAME, - DNSManagerSettings, - DNSManagerState, ) -from ldap_protocol.dns.dto import DNSSettingDTO +from ldap_protocol.dns.dto import DNSSettingsDTO, PowerDNSSettingsDTO +from ldap_protocol.dns.enums import DNSManagerState from repo.pg.tables import queryable_attr as qa @@ -29,54 +29,44 @@ def __init__(self, session: AsyncSession) -> None: 
"""Initialize DNS gateway.""" self._session = session - async def setup_dns_state( - self, - state: DNSManagerState | str, - ) -> None: - """Set up DNS server and DNS manager.""" - await self._session.execute( - update(CatalogueSetting) - .values({"value": state}) - .filter_by(name=DNS_MANAGER_STATE_NAME), - ) - async def get(self, name: str) -> CatalogueSetting | None: """Get DNS by name.""" return await self._session.scalar( - select(CatalogueSetting).filter_by(name=name), - ) + select(CatalogueSetting) + .filter_by(name=name), + ) # fmt: skip async def create(self, data: CatalogueSetting) -> None: """Create DNS.""" self._session.add(data) await self._session.commit() - async def get_dns_settings(self) -> dict[str, str]: + async def get_settings_from_db(self) -> dict[str, str]: """Get DNS managers.""" - return { - setting.name: setting.value - for setting in await self._session.scalars( - select(CatalogueSetting).filter( - qa(CatalogueSetting.name).in_( - [ - DNS_MANAGER_ZONE_NAME, - DNS_MANAGER_IP_ADDRESS_NAME, - DNS_MANAGER_TSIG_KEY_NAME, - ], + settings = await self._session.scalars( + select(CatalogueSetting) + .filter( + qa(CatalogueSetting.name).in_(( + DNS_MANAGER_ZONE_NAME, + DNS_MANAGER_IP_ADDRESS_NAME, + DNS_MANAGER_TSIG_KEY_NAME, ), ), - ) - } + ), + ) # fmt: skip + result = {setting.name: setting.value for setting in settings} + + return result async def update_settings( self, - data: DNSSettingDTO, + data: DNSSettingsDTO, ) -> None: """Update DNS settings.""" settings = [ ( qa(CatalogueSetting.name) == DNS_MANAGER_ZONE_NAME, - data.zone_name, + data.domain, ), ( qa(CatalogueSetting.name) == DNS_MANAGER_IP_ADDRESS_NAME, @@ -111,14 +101,14 @@ async def update_settings( async def create_settings( self, - data: DNSSettingDTO, + data: DNSSettingsDTO, ) -> None: """Create DNS settings.""" self._session.add_all( [ CatalogueSetting( name=DNS_MANAGER_ZONE_NAME, - value=data.zone_name or "", + value=data.domain or "", ), CatalogueSetting( 
name=DNS_MANAGER_IP_ADDRESS_NAME, @@ -134,22 +124,38 @@ async def create_settings( async def get_dns_manager_settings( self, - resolve_coro: Awaitable[str], - ) -> DNSManagerSettings: + app_settings: Settings, + domain: str, + ) -> DNSSettingsDTO: """Get DNS manager settings.""" - settings = await self.get_dns_settings() - dns_server_ip = settings.get(DNS_MANAGER_IP_ADDRESS_NAME) + power_dns_settings = PowerDNSSettingsDTO( + auth_server_ip=app_settings.PDNS_AUTH_SERVER_IP, + recursor_server_ip=app_settings.PDNS_RECURSOR_SERVER_IP, + ) + dns_settings = DNSSettingsDTO( + domain=domain, + dns_server_ip=None, + tsig_key=None, + default_nameserver=app_settings.DEFAULT_NAMESERVER, + power_dns_settings=power_dns_settings, + ) - if await self.get_dns_state() == DNSManagerState.SELFHOSTED: - dns_server_ip = await resolve_coro + if await self.get_state() == DNSManagerState.HOSTED: + settings_from_db = await self.get_settings_from_db() + dns_settings.domain = settings_from_db.get( + DNS_MANAGER_ZONE_NAME, + "", + ) + dns_settings.dns_server_ip = IPv4Address( + settings_from_db.get(DNS_MANAGER_IP_ADDRESS_NAME), + ) + dns_settings.tsig_key = settings_from_db.get( + DNS_MANAGER_TSIG_KEY_NAME, + ) - return DNSManagerSettings( - zone_name=settings.get(DNS_MANAGER_ZONE_NAME), - dns_server_ip=dns_server_ip, - tsig_key=settings.get(DNS_MANAGER_TSIG_KEY_NAME), - ) + return dns_settings - async def get_dns_state(self) -> DNSManagerState: + async def get_state(self) -> DNSManagerState: """Get DNS state.""" state = await self.get(DNS_MANAGER_STATE_NAME) if state is None: @@ -161,3 +167,23 @@ async def get_dns_state(self) -> DNSManagerState: ) return DNSManagerState.NOT_CONFIGURED return DNSManagerState(state.value) + + async def set_state( + self, + state: DNSManagerState, + ) -> None: + """Set DNS state.""" + existing_state = await self.get(DNS_MANAGER_STATE_NAME) + if existing_state is None: + await self.create( + CatalogueSetting( + name=DNS_MANAGER_STATE_NAME, + value=state, + ), + ) 
+ else: + await self._session.execute( + update(CatalogueSetting) + .values({"value": state}) + .filter_by(name=DNS_MANAGER_STATE_NAME), + ) diff --git a/app/ldap_protocol/dns/dto.py b/app/ldap_protocol/dns/dto.py index 8edd4d781..20596be16 100644 --- a/app/ldap_protocol/dns/dto.py +++ b/app/ldap_protocol/dns/dto.py @@ -4,14 +4,118 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from dataclasses import dataclass +from dataclasses import dataclass, field from ipaddress import IPv4Address, IPv6Address +from ldap_protocol.dns.enums import ( + DNSForwarderServerStatus, + DNSRecordType, + PowerDNSRecordChangeType, + PowerDNSZoneType, +) + + +@dataclass +class CommandResponse: + success: bool = True + message: str = " " + + +@dataclass +class RuleEntry: + id: int + match: str + action: str + + +@dataclass +class DNSdistRulesTable: + rules: list[RuleEntry] + count: int + + +@dataclass +class DNSdistCommand: + command: str + + +@dataclass +class DNSdistCommandsDelta: + delta: list[DNSdistCommand] + count: int + @dataclass -class DNSSettingDTO: - """DNS settings entity.""" +class PowerDNSSettingsDTO: + """PowerDNS related settings.""" + + auth_server_ip: str + recursor_server_ip: str - zone_name: str | None - dns_server_ip: str | IPv4Address | IPv6Address | None + +@dataclass +class DNSSettingsDTO: + """DNS settings DTO.""" + + domain: str + dns_server_ip: IPv4Address | IPv6Address | None tsig_key: str | None + default_nameserver: str + power_dns_settings: PowerDNSSettingsDTO | None = field(default=None) + + +@dataclass +class DNSRecordDTO: + """DNS record DTO.""" + + content: str + disabled: bool + modified_at: int | None = None + + +@dataclass +class DNSRRSetDTO: + """DNS RRSet(Resource Record Set) DTO.""" + + name: str + type: DNSRecordType + records: list[DNSRecordDTO] + changetype: PowerDNSRecordChangeType | None = None + ttl: int | None = None + + +@dataclass +class DNSZoneBaseDTO: + """DNS zone DTO.""" + + id: str + name: str + 
rrsets: list[DNSRRSetDTO] = field(default_factory=list) + type: str = "zone" + + +@dataclass +class DNSMasterZoneDTO(DNSZoneBaseDTO): + """DNS master zone DTO.""" + + dnssec: bool = field(default=False) + nameservers: list[str] = field(default_factory=list) + kind: PowerDNSZoneType = PowerDNSZoneType.MASTER + + +@dataclass +class DNSForwardZoneDTO(DNSZoneBaseDTO): + """DNS forward zone DTO.""" + + servers: list[str] = field(default_factory=list) + recursion_desired: bool = field(default=False) + kind: PowerDNSZoneType = PowerDNSZoneType.FORWARDED + + +@dataclass +class DNSForwardServerStatus: + """Forward DNS server status.""" + + ip: str + status: DNSForwarderServerStatus + FQDN: str | None diff --git a/app/ldap_protocol/dns/enums.py b/app/ldap_protocol/dns/enums.py new file mode 100644 index 000000000..218f125be --- /dev/null +++ b/app/ldap_protocol/dns/enums.py @@ -0,0 +1,63 @@ +"""Enums for DNS module. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from enum import Enum, StrEnum + + +class DNSdistCommandTypes(Enum): + """PDNSdist command types.""" + + GENERIC = "generic" + SHOW_RULES = "show_rules" + COMMANDS_DELTA = "commands_delta" + + +class DNSRecordType(StrEnum): + """PowerDNS Record Types.""" + + A = "A" + AAAA = "AAAA" + CNAME = "CNAME" + MX = "MX" + TXT = "TXT" + NS = "NS" + SOA = "SOA" + SRV = "SRV" + PTR = "PTR" + + +class PowerDNSZoneType(StrEnum): + """PowerDNS Zone Types.""" + + MASTER = "Master" + FORWARDED = "Forwarded" + NATIVE = "Native" + PRIMARY = "Primary" + + +class PowerDNSRecordChangeType(StrEnum): + """PowerDNS Record Change Types.""" + + REPLACE = "REPLACE" + DELETE = "DELETE" + EXTEND = "EXTEND" + PRUNE = "PRUNE" + + +class DNSForwarderServerStatus(StrEnum): + """Forwarder DNS server statuses.""" + + VALIDATED = "validated" + NOT_VALIDATED = "not validated" + NOT_FOUND = "not found" + + +class DNSManagerState(StrEnum): + """DNSManager state enum.""" + + 
NOT_CONFIGURED = "0" + SELFHOSTED = "1" + HOSTED = "2" diff --git a/app/ldap_protocol/dns/exceptions.py b/app/ldap_protocol/dns/exceptions.py index 5b9da9f5e..b4320cf89 100644 --- a/app/ldap_protocol/dns/exceptions.py +++ b/app/ldap_protocol/dns/exceptions.py @@ -14,15 +14,18 @@ class ErrorCodes(IntEnum): BASE_ERROR = 0 DNS_SETUP_ERROR = 1 - DNS_RECORD_CREATE_ERROR = 2 - DNS_RECORD_UPDATE_ERROR = 3 - DNS_RECORD_DELETE_ERROR = 4 - DNS_ZONE_CREATE_ERROR = 5 - DNS_ZONE_UPDATE_ERROR = 6 - DNS_ZONE_DELETE_ERROR = 7 - DNS_UPDATE_SERVER_OPTIONS_ERROR = 8 - DNS_CONNECTION_ERROR = 9 - DNS_NOT_IMPLEMENTED_ERROR = 10 + DNS_RECORD_GET_ERROR = 2 + DNS_RECORD_CREATE_ERROR = 3 + DNS_RECORD_UPDATE_ERROR = 4 + DNS_RECORD_DELETE_ERROR = 5 + DNS_ZONE_GET_ERROR = 6 + DNS_ZONE_CREATE_ERROR = 7 + DNS_ZONE_UPDATE_ERROR = 8 + DNS_ZONE_DELETE_ERROR = 9 + DNS_UPDATE_SERVER_OPTIONS_ERROR = 10 + DNS_CONNECTION_ERROR = 11 + DNS_NOT_IMPLEMENTED_ERROR = 12 + DNS_UNAVAILABLE_ERROR = 13 class DNSError(BaseDomainException): @@ -43,6 +46,12 @@ class DNSRecordCreateError(DNSError): code = ErrorCodes.DNS_RECORD_CREATE_ERROR +class DNSRecordGetError(DNSError): + """DNS record get error.""" + + code = ErrorCodes.DNS_RECORD_GET_ERROR + + class DNSRecordUpdateError(DNSError): """DNS record update error.""" @@ -61,6 +70,12 @@ class DNSZoneCreateError(DNSError): code = ErrorCodes.DNS_ZONE_CREATE_ERROR +class DNSZoneGetError(DNSError): + """DNS zone get error.""" + + code = ErrorCodes.DNS_ZONE_GET_ERROR + + class DNSZoneUpdateError(DNSError): """DNS zone update error.""" @@ -89,3 +104,37 @@ class DNSNotImplementedError(DNSError): """DNS not implemented error.""" code = ErrorCodes.DNS_NOT_IMPLEMENTED_ERROR + + +class DNSUnavailableError(DNSError): + """DNS server is unavailable.""" + + code = ErrorCodes.DNS_UNAVAILABLE_ERROR + + +class DNSCreateEntryError(DNSError): + """DNS create entry error.""" + + +class DNSDeleteEntryError(DNSError): + """DNS delete entry error.""" + + +class 
DNSUpdateEntryError(DNSError): + """DNS update entry error.""" + + +class DNSEntryNotFoundError(DNSError): + """DNS entry not found error.""" + + +class DNSValidationError(DNSError): + """DNS validation error.""" + + +class DNSNotSupportedError(DNSError): + """DNS not supported error.""" + + +class DNSdistError(DNSError): + """DNS dist error.""" diff --git a/app/ldap_protocol/dns/managers/__init__.py b/app/ldap_protocol/dns/managers/__init__.py new file mode 100644 index 000000000..5422d2de7 --- /dev/null +++ b/app/ldap_protocol/dns/managers/__init__.py @@ -0,0 +1,11 @@ +from ldap_protocol.dns.managers.abstract_dns_manager import AbstractDNSManager +from ldap_protocol.dns.managers.power_dns_manager import PowerDNSManager +from ldap_protocol.dns.managers.remote_dns_manager import RemoteDNSManager +from ldap_protocol.dns.managers.stub_dns_manager import StubDNSManager + +__all__ = [ + "PowerDNSManager", + "RemoteDNSManager", + "AbstractDNSManager", + "StubDNSManager", +] diff --git a/app/ldap_protocol/dns/managers/abstract_dns_manager.py b/app/ldap_protocol/dns/managers/abstract_dns_manager.py new file mode 100644 index 000000000..90cf0a14e --- /dev/null +++ b/app/ldap_protocol/dns/managers/abstract_dns_manager.py @@ -0,0 +1,119 @@ +"""Abstract DNS manager for DNS server managing. 
+
+Copyright (c) 2025 MultiFactor
+License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
+"""
+
+from abc import abstractmethod
+from ipaddress import IPv4Address, IPv6Address
+
+from ldap_protocol.dns.clients.abstract_client import (
+    AbstractDNSForwardHTTPClient,
+    AbstractDNSMasterHTTPClient,
+)
+from ldap_protocol.dns.dto import (
+    DNSForwardServerStatus,
+    DNSForwardZoneDTO,
+    DNSMasterZoneDTO,
+    DNSRRSetDTO,
+    DNSSettingsDTO,
+)
+
+
+class AbstractDNSManager:
+    """Abstract DNS manager class."""
+
+    _dns_settings: DNSSettingsDTO
+    _dns_master_client: AbstractDNSMasterHTTPClient | None = None
+    _dns_forward_client: AbstractDNSForwardHTTPClient | None = None
+
+    def __init__(
+        self,
+        settings: DNSSettingsDTO,
+    ) -> None:
+        """Set up DNS manager."""
+        self._dns_settings = settings
+
+    @abstractmethod
+    async def setup(
+        self,
+        dns_settings: DNSSettingsDTO,
+        is_migration: bool = False,
+    ) -> None: ...
+
+    @abstractmethod
+    async def create_record(
+        self,
+        zone_id: str,
+        record: DNSRRSetDTO,
+    ) -> None: ...
+
+    @abstractmethod
+    async def update_record(
+        self,
+        zone_id: str,
+        record: DNSRRSetDTO,
+    ) -> None: ...
+
+    @abstractmethod
+    async def delete_record(
+        self,
+        zone_id: str,
+        record: DNSRRSetDTO,
+    ) -> None: ...
+
+    @abstractmethod
+    async def get_records(
+        self,
+        zone_id: str,
+    ) -> list[DNSRRSetDTO]: ...
+
+    @abstractmethod
+    async def get_master_zones(self) -> list[DNSMasterZoneDTO]: ...
+
+    @abstractmethod
+    async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: ...
+
+    @abstractmethod
+    async def create_master_zone(
+        self,
+        zone: DNSMasterZoneDTO,
+        is_empty: bool = False,
+    ) -> None: ...
+
+    @abstractmethod
+    async def create_forward_zone(
+        self,
+        zone: DNSForwardZoneDTO,
+    ) -> None: ...
+
+    @abstractmethod
+    async def update_master_zone(
+        self,
+        zone: DNSMasterZoneDTO,
+    ) -> None: ...
+
+    @abstractmethod
+    async def update_forward_zone(
+        self,
+        zone: DNSForwardZoneDTO,
+    ) -> None: ...
+ + @abstractmethod + async def delete_master_zone( + self, + zone_id: str, + ) -> None: ... + + @abstractmethod + async def delete_forward_zone( + self, + zone_id: str, + ) -> None: ... + + @abstractmethod + async def check_forward_dns_server( + self, + dns_server_ip: IPv4Address | IPv6Address, + host_dns_servers: list[str], + ) -> DNSForwardServerStatus: ... diff --git a/app/ldap_protocol/dns/managers/power_dns_manager.py b/app/ldap_protocol/dns/managers/power_dns_manager.py new file mode 100644 index 000000000..d42d478b7 --- /dev/null +++ b/app/ldap_protocol/dns/managers/power_dns_manager.py @@ -0,0 +1,352 @@ +"""PowerDNS API manager module. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import asyncio +from ipaddress import IPv4Address, IPv6Address + +import dns.asyncresolver + +from ldap_protocol.dns.clients import ( + PowerDNSAuthHTTPClient, + PowerDNSDistClient, + PowerDNSRecursorHTTPClient, +) +from ldap_protocol.dns.constants import ( + DEFAULT_FORWARD_ZONE_NAMES, + DNS_FIRST_SETUP_RECORDS, +) +from ldap_protocol.dns.dto import ( + DNSForwardServerStatus, + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRecordDTO, + DNSRRSetDTO, + DNSSettingsDTO, +) +from ldap_protocol.dns.enums import ( + DNSForwarderServerStatus, + DNSRecordType, + PowerDNSRecordChangeType, +) +from ldap_protocol.dns.exceptions import ( + DNSError, + DNSRecordCreateError, + DNSRecordDeleteError, + DNSRecordGetError, + DNSRecordUpdateError, + DNSSetupError, + DNSZoneCreateError, + DNSZoneDeleteError, + DNSZoneGetError, + DNSZoneUpdateError, +) +from ldap_protocol.dns.managers.abstract_dns_manager import AbstractDNSManager +from ldap_protocol.dns.utils import create_initial_zone_records, logger_wraps + + +class PowerDNSManager(AbstractDNSManager): + """Manager for interacting with the PowerDNS API.""" + + _power_dns_auth_client: PowerDNSAuthHTTPClient + _power_dns_recursor_client: PowerDNSRecursorHTTPClient + 
_dnsdist_client: PowerDNSDistClient + + def __init__( + self, + settings: DNSSettingsDTO, + power_dns_auth_client: PowerDNSAuthHTTPClient, + power_dns_recursor_client: PowerDNSRecursorHTTPClient, + dnsdist_client: PowerDNSDistClient, + ) -> None: + """Initialize the PowerDNS API repository.""" + super().__init__(settings) + self._power_dns_auth_client = power_dns_auth_client + self._power_dns_recursor_client = power_dns_recursor_client + self._dnsdist_client = dnsdist_client + + @staticmethod + def _normalize_dns_name(name: str) -> str: + """Normalize DNS name by ensuring it ends with a dot.""" + return name if name.endswith(".") else f"{name}." + + @logger_wraps() + async def setup( + self, + dns_settings: DNSSettingsDTO, + is_migration: bool = False, + ) -> None: + """Set up DNS server and DNS manager.""" + records = [] + if dns_settings.power_dns_settings is None: + raise DNSSetupError("PowerDNS settings is not set.") + + if not is_migration: + for record in DNS_FIRST_SETUP_RECORDS: + records.append( + DNSRRSetDTO( + name=f"{record['name']}{self._dns_settings.domain}.", + type=DNSRecordType(record["type"]), + records=[ + DNSRecordDTO( + content=f"{record['value']}{self._dns_settings.domain}.", + disabled=False, + modified_at=None, + ), + ], + changetype=PowerDNSRecordChangeType.EXTEND, + ttl=3600, + ), + ) + + try: + self._dnsdist_client.setup_dnsdist( + dns_settings.power_dns_settings.recursor_server_ip, + ) + self._dnsdist_client.add_server( + dns_settings.power_dns_settings.auth_server_ip, + "master", + ) + if not is_migration: + await self.create_master_zone( + DNSMasterZoneDTO( + id=self._dns_settings.domain, + name=self._dns_settings.domain, + dnssec=False, + rrsets=records, + ), + ) + except DNSZoneCreateError as e: + raise DNSSetupError(f"Failed to set up DNS: {e}") + + @logger_wraps() + async def create_record(self, zone_id: str, record: DNSRRSetDTO) -> None: + """Create a DNS record in the specified zone.""" + record.name = 
self._normalize_dns_name(record.name) + + record.changetype = PowerDNSRecordChangeType.REPLACE + + try: + await self._power_dns_auth_client.record_action(zone_id, record) + except DNSError as e: + raise DNSRecordCreateError(f"Failed to create DNS record: {e}") + + @logger_wraps() + async def get_records(self, zone_id: str) -> list[DNSRRSetDTO]: + """Retrieve all DNS records for the specified zone.""" + try: + return await self._power_dns_auth_client.get_records(zone_id) + except DNSError as e: + raise DNSRecordGetError(f"Failed to get DNS records: {e}") + + @logger_wraps() + async def update_record(self, zone_id: str, record: DNSRRSetDTO) -> None: + """Update a DNS record in the specified zone.""" + record.name = self._normalize_dns_name(record.name) + + record.changetype = PowerDNSRecordChangeType.REPLACE + + try: + await self._power_dns_auth_client.record_action(zone_id, record) + except DNSError as e: + raise DNSRecordUpdateError(f"Failed to update DNS record: {e}") + + @logger_wraps() + async def delete_record(self, zone_id: str, record: DNSRRSetDTO) -> None: + """Delete a DNS record from the specified zone.""" + record.name = self._normalize_dns_name(record.name) + + record.changetype = PowerDNSRecordChangeType.DELETE + + try: + await self._power_dns_auth_client.record_action(zone_id, record) + except DNSError as e: + raise DNSRecordDeleteError(f"Failed to delete DNS record: {e}") + + @logger_wraps() + async def create_master_zone( + self, + zone: DNSMasterZoneDTO, + is_empty: bool = False, + ) -> None: + """Create a master DNS zone.""" + zone.name = self._normalize_dns_name(zone.name) + + if not is_empty: + zone.nameservers.append(f"ns1.{zone.name}") + + records = await create_initial_zone_records( + zone.name, + self._dns_settings.default_nameserver, + ) + zone.rrsets.extend(records) + + try: + await self._power_dns_auth_client.create_master_zone(zone) + self._dnsdist_client.add_zone_rule( + zone.name if not zone.name.endswith(".") else zone.name[:-1], + ) + 
except DNSError as e: + raise DNSZoneCreateError(f"Failed to create DNS zone: {e}") + + @logger_wraps() + async def create_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + """Create a forward DNS zone.""" + zone.name = self._normalize_dns_name(zone.name) + + try: + await self._power_dns_recursor_client.create_forward_zone(zone) + except DNSError as e: + raise DNSZoneCreateError(f"Failed to create DNS zone: {e}") + + @logger_wraps() + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: + """Retrieve all DNS zones.""" + try: + zones = await self._power_dns_auth_client.get_master_zones() + except DNSError as e: + raise DNSZoneGetError(f"Failed to get DNS zones: {e}") + + for zone in zones: + zone.rrsets = await self.get_records(zone.id) + + return zones + + @logger_wraps() + async def get_master_zone_by_id(self, zone_id: str) -> DNSMasterZoneDTO: + """Get master DNS zone by ID.""" + try: + return await self._power_dns_auth_client.get_master_zone_by_id( + zone_id, + ) + except DNSError as e: + raise DNSZoneGetError(f"Failed to get DNS zones: {e}") + + @logger_wraps() + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: + """Retrieve all forward DNS zones.""" + try: + forward_zones = ( + await self._power_dns_recursor_client.get_forward_zones() + ) + return [ + zone + for zone in forward_zones + if zone.name not in DEFAULT_FORWARD_ZONE_NAMES + ] + except DNSError as e: + raise DNSZoneGetError(f"Failed to get DNS zones: {e}") + + @logger_wraps() + async def update_master_zone(self, zone: DNSMasterZoneDTO) -> None: + """Update a master DNS zone.""" + zone.name = self._normalize_dns_name(zone.name) + try: + await self._power_dns_auth_client.update_master_zone(zone.id, zone) + except DNSError as e: + raise DNSZoneUpdateError(f"Failed to update DNS zone: {e}") + + @logger_wraps() + async def update_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + """Update a forward DNS zone.""" + zone.name = self._normalize_dns_name(zone.name) + + try: + await 
self._power_dns_recursor_client.update_forward_zone( + zone.id, + zone, + ) + except DNSError as e: + raise DNSZoneUpdateError(f"Failed to update DNS zone: {e}") + + @logger_wraps() + async def delete_master_zone(self, zone_id: str) -> None: + """Delete a DNS zone.""" + zone = await self.get_master_zone_by_id(zone_id) + + try: + await self._power_dns_auth_client.delete_master_zone(zone_id) + self._dnsdist_client.remove_zone_rule(zone.name[:-1]) + except DNSError as e: + raise DNSZoneDeleteError(f"Failed to delete DNS zone: {e}") + + @logger_wraps() + async def delete_forward_zone(self, zone_id: str) -> None: + """Delete a DNS forward zone.""" + try: + await self._power_dns_recursor_client.delete_forward_zone(zone_id) + except DNSError as e: + raise DNSZoneDeleteError(f"Failed to delete DNS zone: {e}") + + @logger_wraps() + async def find_forward_dns_fqdn( + self, + dns_server_ip: IPv4Address | IPv6Address, + host_dns_servers: list[str], + ) -> str | None: + """Find forward DNS FQDN.""" + reversed_ip = ( + ".".join(reversed((str(dns_server_ip)).split("."))) + + ".in-addr.arpa" + ) + + async def get_fqdn_and_latency( + server: str, + ) -> tuple[float, str | None]: + resolver = dns.asyncresolver.Resolver() + resolver.nameservers = [server] + resolver.timeout = 10 + + try: + event_loop = asyncio.get_running_loop() + start_time = event_loop.time() + fqdn = await resolver.resolve(reversed_ip, DNSRecordType.PTR) + latency = event_loop.time() - start_time + + return (latency, fqdn[0].to_text()) + except ( + dns.asyncresolver.NoAnswer, + dns.asyncresolver.NXDOMAIN, + ): + return (float("inf"), None) + + fqdn_list = await asyncio.gather( + *(get_fqdn_and_latency(server) for server in host_dns_servers), + ) + fqdn_list.sort(key=lambda x: x[0]) + return fqdn_list[0][1] if fqdn_list else None + + @logger_wraps() + async def check_forward_dns_server( + self, + dns_server_ip: IPv4Address | IPv6Address, + host_dns_servers: list[str], + ) -> DNSForwardServerStatus: + 
str_dns_server_ip = str(dns_server_ip) + + try: + fqdn = await self.find_forward_dns_fqdn( + dns_server_ip, + host_dns_servers, + ) + except (dns.asyncresolver.NoAnswer, dns.asyncresolver.NXDOMAIN): + return DNSForwardServerStatus( + str_dns_server_ip, + DNSForwarderServerStatus.NOT_VALIDATED, + None, + ) + + if not fqdn: + return DNSForwardServerStatus( + str_dns_server_ip, + DNSForwarderServerStatus.NOT_FOUND, + None, + ) + + return DNSForwardServerStatus( + str_dns_server_ip, + DNSForwarderServerStatus.VALIDATED, + fqdn, + ) diff --git a/app/ldap_protocol/dns/managers/remote_dns_manager.py b/app/ldap_protocol/dns/managers/remote_dns_manager.py new file mode 100644 index 000000000..c41f844bc --- /dev/null +++ b/app/ldap_protocol/dns/managers/remote_dns_manager.py @@ -0,0 +1,193 @@ +"""DNS service for remote DNS server managing. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address + +from dns.asyncquery import inbound_xfr as make_inbound_xfr, tcp as asynctcp +from dns.message import Message, make_query as make_dns_query +from dns.name import from_text +from dns.rdataclass import IN +from dns.rdatatype import AXFR +from dns.tsig import Key as TsigKey +from dns.update import Update +from dns.zone import Zone + +from ldap_protocol.dns.dto import ( + DNSForwardServerStatus, + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRecordDTO, + DNSRRSetDTO, + DNSSettingsDTO, +) +from ldap_protocol.dns.exceptions import ( + DNSConnectionError, + DNSNotImplementedError, +) +from ldap_protocol.dns.managers.abstract_dns_manager import AbstractDNSManager +from ldap_protocol.dns.utils import logger_wraps + + +class RemoteDNSManager(AbstractDNSManager): + """DNS server manager.""" + + async def _send(self, action: Message) -> None: + """Send request to DNS server.""" + if self._dns_settings.tsig_key is not None: + action.use_tsig( + keyring=TsigKey("zone.", 
self._dns_settings.tsig_key), + keyname="zone.", + ) + + if self._dns_settings.dns_server_ip is None: + raise DNSConnectionError + + await asynctcp(action, str(self._dns_settings.dns_server_ip)) + + async def setup( + self, + dns_settings: DNSSettingsDTO, # noqa: ARG002 + is_migration: bool = False, # noqa: ARG002 + ) -> None: + """Set up DNS server and DNS manager.""" + raise DNSNotImplementedError + + @logger_wraps() + async def create_record( + self, + zone_id: str, + record: DNSRRSetDTO, + ) -> None: + """Create DNS record.""" + action = Update(self._dns_settings.domain or zone_id) + action.add( + record.name, + record.ttl, + record.type, + record.records[0].content, + ) + + await self._send(action) + + @logger_wraps() + async def get_records(self, zone_id: str) -> list[DNSRRSetDTO]: + """Get all DNS records.""" + if ( + self._dns_settings.dns_server_ip is None + or self._dns_settings.domain is None + ): + raise DNSConnectionError + + zone = from_text(self._dns_settings.domain or zone_id) + zone_tm = Zone(zone) + query = make_dns_query(zone, AXFR, IN) + + if self._dns_settings.tsig_key is not None: + query.use_tsig( + keyring=TsigKey("zone.", self._dns_settings.tsig_key), + keyname="zone.", + ) + + await make_inbound_xfr( + str(self._dns_settings.dns_server_ip), + zone_tm, + ) + + return [ + DNSRRSetDTO( + name=name.to_text() + f".{self._dns_settings.domain}.", + type=rdata.rdtype.name, + records=[ + DNSRecordDTO( + content=rdata.to_text(), + disabled=False, + ), + ], + ttl=ttl, + ) + for name, ttl, rdata in zone_tm.iterate_rdatas() + ] + + @logger_wraps() + async def update_record( + self, + zone_id: str, + record: DNSRRSetDTO, + ) -> None: + """Update DNS record.""" + action = Update(self._dns_settings.domain or zone_id) + action.replace( + record.name, + record.ttl, + record.type, + record.records[0].content, + ) + await self._send(action) + + @logger_wraps() + async def delete_record( + self, + zone_id: str, + record: DNSRRSetDTO, + ) -> None: + """Delete 
DNS record.""" + action = Update(self._dns_settings.domain or zone_id) + action.delete( + record.name, + record.type, + record.records[0].content, + ) + await self._send(action) + + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: + raise DNSNotImplementedError + + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: + raise DNSNotImplementedError + + async def create_master_zone( + self, + zone: DNSMasterZoneDTO, # noqa: ARG002 + is_empty: bool = False, # noqa: ARG002 + ) -> None: + raise DNSNotImplementedError + + async def create_forward_zone( + self, + zone: DNSForwardZoneDTO, # noqa: ARG002 + ) -> None: + raise DNSNotImplementedError + + async def update_master_zone( + self, + zone: DNSMasterZoneDTO, # noqa: ARG002 + ) -> None: + raise DNSNotImplementedError + + async def update_forward_zone( + self, + zone: DNSForwardZoneDTO, # noqa: ARG002 + ) -> None: + raise DNSNotImplementedError + + async def delete_master_zone( + self, + zone_id: str, # noqa: ARG002 + ) -> None: + raise DNSNotImplementedError + + async def delete_forward_zone( + self, + zone_id: str, # noqa: ARG002 + ) -> None: + raise DNSNotImplementedError + + async def check_forward_dns_server( + self, + dns_server_ip: IPv4Address | IPv6Address, # noqa: ARG002 + host_dns_servers: list[str], # noqa: ARG002 + ) -> DNSForwardServerStatus: + raise DNSNotImplementedError diff --git a/app/ldap_protocol/dns/managers/stub_dns_manager.py b/app/ldap_protocol/dns/managers/stub_dns_manager.py new file mode 100644 index 000000000..2dceeb626 --- /dev/null +++ b/app/ldap_protocol/dns/managers/stub_dns_manager.py @@ -0,0 +1,107 @@ +"""Stub calls for DNS server API. 
+ +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Address, IPv6Address + +from ldap_protocol.dns.dto import ( + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRRSetDTO, + DNSSettingsDTO, +) +from ldap_protocol.dns.managers.abstract_dns_manager import AbstractDNSManager +from ldap_protocol.dns.utils import logger_wraps + + +class StubDNSManager(AbstractDNSManager): + """Stub client.""" + + @logger_wraps(is_stub=True) + async def setup( + self, + dns_settings: DNSSettingsDTO, + is_migration: bool = False, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def create_record( + self, + zone_id: str, + record: DNSRRSetDTO, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def update_record( + self, + zone_id: str, + record: DNSRRSetDTO, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def delete_record( + self, + zone_id: str, + record: DNSRRSetDTO, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def get_records( + self, + zone_id: str, # noqa: ARG002 + ) -> list[DNSRRSetDTO]: + return [] + + @logger_wraps(is_stub=True) + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: + return [] + + @logger_wraps(is_stub=True) + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: + return [] + + @logger_wraps(is_stub=True) + async def create_master_zone( + self, + zone: DNSMasterZoneDTO, + is_empty: bool = False, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def create_forward_zone( + self, + zone: DNSForwardZoneDTO, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def update_master_zone( + self, + zone: DNSMasterZoneDTO, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def update_forward_zone( + self, + zone: DNSForwardZoneDTO, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def delete_master_zone( + self, + zone_id: str, + ) -> None: ... 
+ + @logger_wraps(is_stub=True) + async def delete_forward_zone( + self, + zone_id: str, + ) -> None: ... + + @logger_wraps(is_stub=True) + async def check_forward_dns_server( + self, + dns_server_ip: IPv4Address | IPv6Address, + host_dns_servers: list[str], + ) -> None: ... diff --git a/app/ldap_protocol/dns/remote.py b/app/ldap_protocol/dns/remote.py deleted file mode 100644 index 1c2cb25fd..000000000 --- a/app/ldap_protocol/dns/remote.py +++ /dev/null @@ -1,125 +0,0 @@ -"""DNS service for remote DNS server managing. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from collections import defaultdict - -from dns.asyncquery import inbound_xfr as make_inbound_xfr, tcp as asynctcp -from dns.message import Message, make_query as make_dns_query -from dns.name import from_text -from dns.rdataclass import IN -from dns.rdatatype import AXFR -from dns.tsig import Key as TsigKey -from dns.update import Update -from dns.zone import Zone - -from .base import AbstractDNSManager, DNSRecord, DNSRecords -from .exceptions import DNSConnectionError -from .utils import logger_wraps - - -class RemoteDNSManager(AbstractDNSManager): - """DNS server manager.""" - - async def _send(self, action: Message) -> None: - """Send request to DNS server.""" - if self._dns_settings.tsig_key is not None: - action.use_tsig( - keyring=TsigKey("zone.", self._dns_settings.tsig_key), - keyname="zone.", - ) - - if self._dns_settings.dns_server_ip is None: - raise DNSConnectionError - - await asynctcp(action, self._dns_settings.dns_server_ip) - - @logger_wraps() - async def create_record( - self, - hostname: str, - ip: str, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: - """Create DNS record.""" - action = Update(self._dns_settings.zone_name or zone_name) - action.add(hostname, ttl, record_type, ip) - - await self._send(action) - - @logger_wraps() - async def get_all_records(self) -> 
list[DNSRecords]: - """Get all DNS records.""" - if ( - self._dns_settings.dns_server_ip is None - or self._dns_settings.zone_name is None - ): - raise DNSConnectionError - - zone = from_text(self._dns_settings.zone_name) - zone_tm = Zone(zone) - query = make_dns_query(zone, AXFR, IN) - - if self._dns_settings.tsig_key is not None: - query.use_tsig( - keyring=TsigKey("zone.", self._dns_settings.tsig_key), - keyname="zone.", - ) - - await make_inbound_xfr( - self._dns_settings.dns_server_ip, - zone_tm, - ) - - result: defaultdict[str, list] = defaultdict(list) - for name, ttl, rdata in zone_tm.iterate_rdatas(): - record_type = rdata.rdtype.name - - if record_type == "SOA": - continue - - result[record_type].append( - DNSRecord( - name=(name.to_text() + f".{self._dns_settings.zone_name}"), - value=rdata.to_text(), - ttl=ttl, - ), - ) - - return [ - DNSRecords(type=record_type, records=records) - for record_type, records in result.items() - ] - - @logger_wraps() - async def update_record( - self, - hostname: str, - ip: str | None, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: - """Update DNS record.""" - action = Update(self._dns_settings.zone_name or zone_name) - action.replace(hostname, ttl, record_type, ip) - - await self._send(action) - - @logger_wraps() - async def delete_record( - self, - hostname: str, - ip: str, - record_type: str, - zone_name: str | None = None, - ) -> None: - """Delete DNS record.""" - action = Update(self._dns_settings.zone_name or zone_name) - action.delete(hostname, record_type, ip) - - await self._send(action) diff --git a/app/ldap_protocol/dns/selfhosted.py b/app/ldap_protocol/dns/selfhosted.py deleted file mode 100644 index 4870e2e70..000000000 --- a/app/ldap_protocol/dns/selfhosted.py +++ /dev/null @@ -1,286 +0,0 @@ -"""DNS service for SelfHosted DNS server managing. 
- -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -import asyncio -from dataclasses import asdict -from ipaddress import IPv4Address, IPv6Address - -import dns.asyncresolver - -import ldap_protocol.dns.exceptions as dns_exc - -from .base import ( - AbstractDNSManager, - DNSForwarderServerStatus, - DNSForwardServerStatus, - DNSForwardZone, - DNSRecords, - DNSRecordType, - DNSServerParam, - DNSZone, - DNSZoneParam, - DNSZoneType, -) -from .utils import logger_wraps - - -class SelfHostedDNSManager(AbstractDNSManager): - """Manager for selfhosted Bind9 DNS server.""" - - @logger_wraps() - async def create_record( - self, - hostname: str, - ip: str, - record_type: DNSRecordType, - ttl: int, - zone_name: str | None = None, - ) -> None: - """Create DNS record.""" - response = await self._http_client.post( - "/record", - json={ - "zone_name": zone_name, - "record_name": hostname, - "record_type": record_type, - "record_value": ip, - "ttl": ttl, - }, - ) - - if response.status_code != 200: - raise dns_exc.DNSRecordCreateError(response.text) - - @logger_wraps() - async def update_record( - self, - hostname: str, - ip: str | None, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: - response = await self._http_client.patch( - "/record", - json={ - "zone_name": zone_name, - "record_name": hostname, - "record_type": record_type, - "record_value": ip, - "ttl": ttl, - }, - ) - - if response.status_code != 200: - raise dns_exc.DNSRecordUpdateError(response.text) - - @logger_wraps() - async def delete_record( - self, - hostname: str, - ip: str, - record_type: str, - zone_name: str | None = None, - ) -> None: - response = await self._http_client.request( - "delete", - "/record", - json={ - "zone_name": zone_name, - "record_name": hostname, - "record_type": record_type, - "record_value": ip, - }, - ) - - if response.status_code != 200: - raise 
dns_exc.DNSRecordDeleteError(response.text) - - @logger_wraps() - async def get_all_records(self) -> list[DNSRecords]: - response = await self._http_client.get("/zone") - - response_data = response.json() - - if ( - isinstance(response_data, list) - and len(response_data) > 0 - and "records" in response_data[0] - ): - return response_data[0]["records"] - else: - return [] - - @logger_wraps() - async def get_all_zones_records(self) -> list[DNSZone]: - response = await self._http_client.get("/zone") - - return response.json() - - @logger_wraps() - async def get_forward_zones(self) -> list[DNSForwardZone]: - response = await self._http_client.get("/zone/forward") - - return response.json() - - @logger_wraps() - async def create_zone( - self, - zone_name: str, - zone_type: DNSZoneType, - nameserver: str | None, - params: list[DNSZoneParam], - ) -> None: - response = await self._http_client.post( - "/zone", - json={ - "zone_name": zone_name, - "zone_type": zone_type, - "nameserver": nameserver, - "params": [asdict(param) for param in params], - }, - ) - - if response.status_code != 200: - raise dns_exc.DNSZoneCreateError(response.text) - - @logger_wraps() - async def update_zone( - self, - zone_name: str, - params: list[DNSZoneParam], - ) -> None: - response = await self._http_client.patch( - "/zone", - json={ - "zone_name": zone_name, - "params": [asdict(param) for param in params], - }, - ) - - if response.status_code != 200: - raise dns_exc.DNSZoneUpdateError(response.text) - - @logger_wraps() - async def delete_zone( - self, - zone_names: list[str], - ) -> None: - for zone_name in zone_names: - response = await self._http_client.request( - "delete", - "/zone", - json={"zone_name": zone_name}, - ) - - if response.status_code != 200: - raise dns_exc.DNSZoneDeleteError(response.text) - - @logger_wraps() - async def find_forward_dns_fqdn( - self, - dns_server_ip: IPv4Address | IPv6Address, - host_dns_servers: list[str], - ) -> str | None: - """Find forward DNS FQDN.""" 
- reversed_ip = ( - ".".join(reversed((str(dns_server_ip)).split("."))) - + ".in-addr.arpa" - ) - - async def get_fqdn_and_latency( - server: str, - ) -> tuple[float, str | None]: - resolver = dns.asyncresolver.Resolver() - resolver.nameservers = [server] - resolver.timeout = 10 - - try: - event_loop = asyncio.get_running_loop() - start_time = event_loop.time() - fqdn = await resolver.resolve( - reversed_ip, - "PTR", - ) - latency = event_loop.time() - start_time - - return (latency, fqdn[0].to_text()) - except ( - dns.asyncresolver.NoAnswer, - dns.asyncresolver.NXDOMAIN, - ): - return (float("inf"), None) - - fqdn_list = await asyncio.gather( - *(get_fqdn_and_latency(server) for server in host_dns_servers), - ) - fqdn_list.sort(key=lambda x: x[0]) - return fqdn_list[0][1] if fqdn_list else None - - @logger_wraps() - async def check_forward_dns_server( - self, - dns_server_ip: IPv4Address | IPv6Address, - host_dns_servers: list[str], - ) -> DNSForwardServerStatus: - str_dns_server_ip = str(dns_server_ip) - - try: - fqdn = await self.find_forward_dns_fqdn( - str_dns_server_ip, - host_dns_servers, - ) - except (dns.asyncresolver.NoAnswer, dns.asyncresolver.NXDOMAIN): - return DNSForwardServerStatus( - str_dns_server_ip, - DNSForwarderServerStatus.NOT_VALIDATED, - None, - ) - - if not fqdn: - return DNSForwardServerStatus( - str_dns_server_ip, - DNSForwarderServerStatus.NOT_FOUND, - None, - ) - - return DNSForwardServerStatus( - str_dns_server_ip, - DNSForwarderServerStatus.VALIDATED, - fqdn, - ) - - @logger_wraps() - async def update_server_options( - self, - params: list[DNSServerParam], - ) -> None: - response = await self._http_client.patch( - "/server/settings", - json=[asdict(param) for param in params], - ) - - if response.status_code != 200: - raise dns_exc.DNSUpdateServerOptionsError(response.text) - - @logger_wraps() - async def get_server_options(self) -> list[DNSServerParam]: - response = await self._http_client.get("/server/settings") - - return 
response.json() - - @logger_wraps() - async def restart_server( - self, - ) -> None: - await self._http_client.get("/server/restart") - - @logger_wraps() - async def reload_zone( - self, - zone_name: str, - ) -> None: - await self._http_client.get(f"/zone/{zone_name}") diff --git a/app/ldap_protocol/dns/stub.py b/app/ldap_protocol/dns/stub.py deleted file mode 100644 index 836a98a62..000000000 --- a/app/ldap_protocol/dns/stub.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Stub calls for DNS server API. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from .base import ( - AbstractDNSManager, - DNSForwardZone, - DNSRecords, - DNSServerParam, - DNSZoneParam, - DNSZoneType, -) -from .utils import logger_wraps - - -class StubDNSManager(AbstractDNSManager): - """Stub client.""" - - @logger_wraps(is_stub=True) - async def create_record( - self, - hostname: str, - ip: str, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def update_record( - self, - hostname: str, - ip: str, - record_type: str, - ttl: int, - zone_name: str | None = None, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def delete_record( - self, - hostname: str, - ip: str, - record_type: str, - zone_name: str | None = None, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def get_all_zones_records(self) -> None: ... - - @logger_wraps(is_stub=True) - async def get_forward_zones(self) -> list[DNSForwardZone]: - return [] - - @logger_wraps(is_stub=True) - async def create_zone( - self, - zone_name: str, - zone_type: DNSZoneType, - nameserver: str | None, - params: list[DNSZoneParam], - ) -> None: ... - - @logger_wraps(is_stub=True) - async def update_zone( - self, - zone_name: str, - params: list[DNSZoneParam] | None, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def delete_zone( - self, - zone_names: list[str], - ) -> None: ... 
- - @logger_wraps(is_stub=True) - async def check_forward_dns_server( - self, - dns_server_ip: str, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def update_server_options( - self, - params: list[DNSServerParam], - ) -> None: ... - - @logger_wraps(is_stub=True) - async def get_server_options(self) -> list[DNSServerParam]: - return [] - - @logger_wraps(is_stub=True) - async def restart_server( - self, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def reload_zone( - self, - zone_name: str, - ) -> None: ... - - @logger_wraps(is_stub=True) - async def get_all_records(self) -> list[DNSRecords]: - """Stub DNS manager get all records.""" - return [] diff --git a/app/ldap_protocol/dns/use_cases.py b/app/ldap_protocol/dns/use_cases.py index 0b5f32291..78dfb2092 100644 --- a/app/ldap_protocol/dns/use_cases.py +++ b/app/ldap_protocol/dns/use_cases.py @@ -10,18 +10,17 @@ from abstract_service import AbstractService from config import Settings from enums import AuthorizationRules -from ldap_protocol.dns.base import ( - AbstractDNSManager, +from ldap_protocol.dns.dns_gateway import DNSStateGateway +from ldap_protocol.dns.dto import ( DNSForwardServerStatus, - DNSForwardZone, - DNSManagerSettings, - DNSRecords, - DNSServerParam, - DNSZone, - DNSZoneParam, - DNSZoneType, + DNSForwardZoneDTO, + DNSMasterZoneDTO, + DNSRRSetDTO, + DNSSettingsDTO, ) -from ldap_protocol.dns.dns_gateway import DNSStateGateway +from ldap_protocol.dns.enums import DNSManagerState +from ldap_protocol.dns.exceptions import DNSError, DNSSetupError +from ldap_protocol.dns.managers.abstract_dns_manager import AbstractDNSManager class DNSUseCase(AbstractService): @@ -31,7 +30,7 @@ def __init__( self, dns_manager: AbstractDNSManager, dns_gateway: DNSStateGateway, - dns_settings: DNSManagerSettings, + dns_settings: DNSSettingsDTO, settings: Settings, ) -> None: """Initialize DNS use case.""" @@ -40,116 +39,94 @@ def __init__( self._dns_settings = dns_settings self._dns_gateway = dns_gateway 
- async def setup_dns( + async def setup( self, - dns_status: str, - domain: str, - dns_ip_address: str | IPv4Address | IPv6Address | None, - tsig_key: str | None, + dns_settings: DNSSettingsDTO | None, ) -> None: """Set up DNS server and DNS manager.""" - setup_data = await self._dns_manager.setup( - dns_status, - domain, - dns_ip_address or self._settings.DNS_BIND_HOST, - tsig_key, - ) - if self._dns_settings.domain is not None: - await self._dns_gateway.update_settings(setup_data) - else: - await self._dns_gateway.create_settings(setup_data) + state = await self._dns_gateway.get_state() - await self._dns_gateway.setup_dns_state(dns_status) + if state == DNSManagerState.SELFHOSTED: + await self._dns_manager.setup( + self._dns_settings, + ) + elif state == DNSManagerState.HOSTED: + if dns_settings is None: + raise DNSSetupError() + if self._dns_settings.dns_server_ip is None: + await self._dns_gateway.create_settings(dns_settings) + else: + await self._dns_gateway.update_settings(dns_settings) + else: + raise DNSSetupError() async def create_record( self, - hostname: str, - ip: str, - record_type: str, - ttl: int | None, - zone_name: str | None = None, + zone_id: str, + record: DNSRRSetDTO, ) -> None: """Create DNS record.""" - await self._dns_manager.create_record( - hostname, - ip, - record_type, - ttl, - zone_name, - ) + await self._dns_manager.create_record(zone_id, record) - async def delete_record( - self, - hostname: str, - ip: str, - record_type: str, - zone_name: str | None = None, - ) -> None: - """Delete DNS record.""" - await self._dns_manager.delete_record( - hostname, - ip, - record_type, - zone_name, - ) + async def get_records(self, zone_id: str) -> list[DNSRRSetDTO]: + """Get all DNS records.""" + return await self._dns_manager.get_records(zone_id) - async def update_record( - self, - hostname: str, - ip: str | None, - record_type: str, - ttl: int | None, - zone_name: str | None = None, - ) -> None: + async def update_record(self, zone_id: str, 
record: DNSRRSetDTO) -> None: """Update DNS record.""" - await self._dns_manager.update_record( - hostname, - ip, - record_type, - ttl, - zone_name, - ) + await self._dns_manager.update_record(zone_id, record) - async def get_all_records(self) -> list[DNSRecords]: - """Get all DNS records.""" - return await self._dns_manager.get_all_records() + async def delete_record(self, zone_id: str, record: DNSRRSetDTO) -> None: + """Delete DNS record.""" + await self._dns_manager.delete_record(zone_id, record) - async def get_all_zones_records(self) -> list[DNSZone]: + async def create_master_zone(self, zone: DNSMasterZoneDTO) -> None: + """Create DNS master zone.""" + await self._dns_manager.create_master_zone(zone) + + async def create_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + """Create DNS forward zone.""" + await self._dns_manager.create_forward_zone(zone) + + async def get_master_zones(self) -> list[DNSMasterZoneDTO]: """Get all DNS zones.""" - return await self._dns_manager.get_all_zones_records() + return await self._dns_manager.get_master_zones() - async def get_forward_zones(self) -> list[DNSForwardZone]: + async def get_forward_zones(self) -> list[DNSForwardZoneDTO]: """Get all forward zones.""" return await self._dns_manager.get_forward_zones() - async def create_zone( - self, - zone_name: str, - zone_type: DNSZoneType, - nameserver: str | None, - params: list[DNSZoneParam], - ) -> None: - """Create DNS zone.""" - await self._dns_manager.create_zone( - zone_name, - zone_type, - nameserver, - params, - ) - - async def update_zone( - self, - zone_name: str, - params: list[DNSZoneParam] | None, - ) -> None: - """Update DNS zone.""" - await self._dns_manager.update_zone(zone_name, params) - - async def delete_zone(self, zone_names: list[str]) -> None: - """Delete DNS zone.""" - await self._dns_manager.delete_zone(zone_names) - - async def check_forward_dns_server( + async def update_master_zone(self, zone: DNSMasterZoneDTO) -> None: + """Update DNS master 
zone.""" + await self._dns_manager.update_master_zone(zone) + + async def update_forward_zone(self, zone: DNSForwardZoneDTO) -> None: + """Update DNS forward zone.""" + await self._dns_manager.update_forward_zone(zone) + + async def delete_master_zones(self, zone_ids: list[str]) -> None: + """Delete DNS master zones.""" + last_error = None + try: + for zone_id in zone_ids: + await self._dns_manager.delete_master_zone(zone_id) + except DNSError as e: + last_error = e + if last_error: + raise last_error + + async def delete_forward_zones(self, zone_ids: list[str]) -> None: + """Delete DNS forward zones.""" + last_error = None + try: + for zone_id in zone_ids: + await self._dns_manager.delete_forward_zone(zone_id) + except DNSError as e: + last_error = e + if last_error: + raise last_error + + async def check_forward_server( self, dns_server_ip: IPv4Address | IPv6Address, host_dns_servers: list[str], @@ -160,40 +137,27 @@ async def check_forward_dns_server( host_dns_servers, ) - async def update_server_options( - self, - params: list[DNSServerParam], - ) -> None: - """Update DNS server options.""" - await self._dns_manager.update_server_options(params) - - async def restart_server(self) -> None: - """Restart DNS server.""" - await self._dns_manager.restart_server() - - async def reload_zone(self, zone_name: str) -> None: - """Reload DNS zone.""" - await self._dns_manager.reload_zone(zone_name) - - async def get_server_options(self) -> list[DNSServerParam]: - """Get DNS server options.""" - return await self._dns_manager.get_server_options() - - async def get_dns_status(self) -> dict[str, str | None]: + async def get_status(self) -> dict[str, str | None]: """Get DNS status.""" return { - "dns_status": await self._dns_gateway.get_dns_state(), - "zone_name": self._dns_settings.zone_name, - "dns_server_ip": self._dns_settings.dns_server_ip, + "dns_status": await self._dns_gateway.get_state(), + "zone_name": self._dns_settings.domain, + "dns_server_ip": 
str(self._dns_settings.dns_server_ip) + if self._dns_settings.dns_server_ip is not None + else None, } - async def check_dns_forward_zone( + async def set_state(self, state: DNSManagerState) -> None: + """Set DNS manager state.""" + await self._dns_gateway.set_state(state) + + async def check_forward_zone( self, data: list[IPv4Address | IPv6Address], ) -> list[DNSForwardServerStatus]: """Check DNS forward zone for availability.""" return [ - await self.check_forward_dns_server( + await self.check_forward_server( dns_server_ip, self._settings.HOST_DNS_SERVERS, ) @@ -201,20 +165,20 @@ async def check_dns_forward_zone( ] PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { - setup_dns.__name__: AuthorizationRules.DNS_SETUP_DNS, + setup.__name__: AuthorizationRules.DNS_SETUP_DNS, create_record.__name__: AuthorizationRules.DNS_CREATE_RECORD, delete_record.__name__: AuthorizationRules.DNS_DELETE_RECORD, update_record.__name__: AuthorizationRules.DNS_UPDATE_RECORD, - get_all_records.__name__: AuthorizationRules.DNS_GET_ALL_RECORDS, - get_dns_status.__name__: AuthorizationRules.DNS_GET_DNS_STATUS, - get_all_zones_records.__name__: AuthorizationRules.DNS_GET_ALL_ZONES_RECORDS, # noqa: E501 - get_forward_zones.__name__: AuthorizationRules.DNS_GET_FORWARD_ZONES, - create_zone.__name__: AuthorizationRules.DNS_CREATE_ZONE, - update_zone.__name__: AuthorizationRules.DNS_UPDATE_ZONE, - delete_zone.__name__: AuthorizationRules.DNS_DELETE_ZONE, - check_dns_forward_zone.__name__: AuthorizationRules.DNS_CHECK_DNS_FORWARD_ZONE, # noqa: E501 - reload_zone.__name__: AuthorizationRules.DNS_RELOAD_ZONE, - update_server_options.__name__: AuthorizationRules.DNS_UPDATE_SERVER_OPTIONS, # noqa: E501 - get_server_options.__name__: AuthorizationRules.DNS_GET_SERVER_OPTIONS, - restart_server.__name__: AuthorizationRules.DNS_RESTART_SERVER, + get_records.__name__: AuthorizationRules.DNS_GET_ALL_RECORDS, + get_status.__name__: AuthorizationRules.DNS_GET_DNS_STATUS, + 
delete_forward_zones.__name__: AuthorizationRules.DNS_DELETE_FWD_ZONES, + get_master_zones.__name__: AuthorizationRules.DNS_GET_MASTER_ZONES, + get_forward_zones.__name__: AuthorizationRules.DNS_GET_FWD_ZONES, + create_master_zone.__name__: AuthorizationRules.DNS_CREATE_MASTER_ZONE, + create_forward_zone.__name__: AuthorizationRules.DNS_CREATE_FWD_ZONE, + update_master_zone.__name__: AuthorizationRules.DNS_UPDATE_MASTER_ZONE, + update_forward_zone.__name__: AuthorizationRules.DNS_UPDATE_FWD_ZONE, + delete_master_zones.__name__: AuthorizationRules.DNS_DELETE_MASTER_ZONES, # noqa: E501 + delete_forward_zones.__name__: AuthorizationRules.DNS_DELETE_FWD_ZONES, + check_forward_zone.__name__: AuthorizationRules.DNS_CHECK_DNS_FORWARD_ZONE, # noqa: E501 } diff --git a/app/ldap_protocol/dns/utils.py b/app/ldap_protocol/dns/utils.py index 9adc21fe9..feccf9eae 100644 --- a/app/ldap_protocol/dns/utils.py +++ b/app/ldap_protocol/dns/utils.py @@ -8,9 +8,21 @@ from typing import Any, Callable from dns.asyncresolver import Resolver as AsyncResolver +from loguru import logger -from .base import log -from .exceptions import DNSConnectionError +from ldap_protocol.dns.dto import DNSRecordDTO, DNSRRSetDTO +from ldap_protocol.dns.enums import DNSRecordType, PowerDNSRecordChangeType +from ldap_protocol.dns.exceptions import DNSConnectionError, DNSError + +log = logger.bind(name="DNSManager") + +log.add( + "logs/dnsmanager_{time:DD-MM-YYYY}.log", + filter=lambda rec: rec["extra"].get("name") == "DNSManager", + retention="10 days", + rotation="1d", + colorize=False, +) def logger_wraps(is_stub: bool = False) -> Callable: @@ -23,17 +35,12 @@ def wrapper(func: Callable) -> Callable: @functools.wraps(func) async def wrapped(*args: str, **kwargs: str) -> Any: logger = log.opt(depth=1) - - logger.info(f"Calling{bus_type}'{name}'") try: result = await func(*args, **kwargs) - except DNSConnectionError as err: - logger.error(f"{name} call raised: {err}") + except DNSError as err: + 
logger.error(f"{name} call in {bus_type} raised: {err}") raise - else: - if not is_stub: - logger.success(f"Executed {name}") return result return wrapped @@ -48,3 +55,52 @@ async def resolve_dns_server_ip(host: str) -> str: if dns_server_ip_resolve is None or dns_server_ip_resolve.rrset is None: raise DNSConnectionError return dns_server_ip_resolve.rrset[0].address + + +async def create_initial_zone_records( + domain: str, + nameserver: str, +) -> list[DNSRRSetDTO]: + """Get initial records for new zone.""" + return [ + DNSRRSetDTO( + name=f"{domain}", + type=DNSRecordType.A, + records=[ + DNSRecordDTO( + content=nameserver, + disabled=False, + modified_at=None, + ), + ], + changetype=PowerDNSRecordChangeType.EXTEND, + ttl=3600, + ), + DNSRRSetDTO( + name=f"ns1.{domain}", + type=DNSRecordType.A, + records=[ + DNSRecordDTO( + content=nameserver, + disabled=False, + modified_at=None, + ), + ], + changetype=PowerDNSRecordChangeType.EXTEND, + ttl=3600, + ), + DNSRRSetDTO( + name=f"{domain}", + type=DNSRecordType.SOA, + records=[ + DNSRecordDTO( + content=f"ns1.{domain} hostmaster.{domain}" + + " 1 10800 3600 604800 3600", + disabled=False, + modified_at=None, + ), + ], + changetype=PowerDNSRecordChangeType.EXTEND, + ttl=3600, + ), + ] diff --git a/app/ldap_protocol/filter_interpreter.py b/app/ldap_protocol/filter_interpreter.py index c456fac00..ce0b301c5 100644 --- a/app/ldap_protocol/filter_interpreter.py +++ b/app/ldap_protocol/filter_interpreter.py @@ -32,16 +32,26 @@ ) from ldap_protocol.utils.helpers import ft_to_dt from ldap_protocol.utils.queries import get_path_filter, get_search_path -from repo.pg.tables import groups_table, queryable_attr as qa, users_table +from repo.pg.tables import ( + directory_table, + groups_table, + queryable_attr as qa, + users_table, +) from .asn1parser import ASN1Row, TagNumbers from .objects import LDAPMatchingRule -from .utils.cte import find_members_recursive_cte, get_filter_from_path +from .utils.cte import ( + 
find_members_recursive_cte, + find_root_group_recursive_cte, + get_filter_from_path, +) _MEMBERS_ATTRS = { "member", "memberof", f"memberof:{LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL}:", + f"member:{LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL}:", } _RULE_POS = 0 @@ -289,6 +299,8 @@ def _get_member_filter_function( return self._recursive_filter_memberof return self._filter_memberof elif attribute == "member": + if oid == LDAPMatchingRule.LDAP_MATCHING_RULE_TRANSITIVE_EVAL: + return self._recursive_filter_member return self._filter_member else: raise ValueError("Incorrect attribute specified") @@ -317,6 +329,24 @@ def _filter_memberof(self, dn: str) -> UnaryExpression: ), ) # type: ignore + def _recursive_filter_member(self, dn: str) -> UnaryExpression: + """Retrieve query conditions with the member attribute (recursive).""" + cte = find_root_group_recursive_cte([dn]) + + source_directory_id = ( + select(directory_table.c.id) + .where(get_filter_from_path(dn)) + .scalar_subquery() + ) + + return qa(Directory.id).in_( + select(cte.c.directory_id) + .where( + cte.c.directory_id != source_directory_id, + ) + .distinct(), + ) # type: ignore + def _filter_member(self, dn: str) -> UnaryExpression: """Retrieve query conditions with the member attribute.""" user_id_subquery = ( diff --git a/app/ldap_protocol/kerberos/base.py b/app/ldap_protocol/kerberos/base.py index d70960738..31dcb355e 100644 --- a/app/ldap_protocol/kerberos/base.py +++ b/app/ldap_protocol/kerberos/base.py @@ -153,8 +153,9 @@ async def setup( @abstractmethod async def add_principal( self, - name: str, - password: str | None, + principal_name: str, + password: str | None = None, + algorithms: list[str] | None = None, timeout: int | float = 1, ) -> None: ... @@ -179,7 +180,13 @@ async def create_or_update_principal_pw( ) -> None: ... @abstractmethod - async def rename_princ(self, name: str, new_name: str) -> None: ... 
+ async def modify_princ( + self, + name: str, + new_name: str | None, + algorithms: list[str] | None = None, + password: str | None = None, + ) -> None: ... @backoff.on_exception( backoff.constant, @@ -202,7 +209,11 @@ async def get_status(self, wait_for_positive: bool = False) -> bool: return status @abstractmethod - async def ktadd(self, names: list[str]) -> httpx.Response: ... + async def ktadd( + self, + names: list[str], + is_rand_key: bool, + ) -> httpx.Response: ... @abstractmethod async def lock_principal(self, name: str) -> None: ... @@ -221,14 +232,17 @@ async def ldap_principal_setup(self, name: str, path: str) -> None: if response.status_code == 200: return - response = await self.client.post("/principal", json={"name": name}) + response = await self.client.post( + "/principal", + json={"principal_name": name}, + ) if response.status_code != 201: log.error(f"Error creating ldap principal: {response.text}") return response = await self.client.post( "/principal/ktadd", - json=[name], + json={"names": [name], "is_rand_key": False}, ) if response.status_code != 200: log.error(f"Error getting keytab: {response.text}") diff --git a/app/ldap_protocol/kerberos/client.py b/app/ldap_protocol/kerberos/client.py index 8dfcb8a23..c85c062fc 100644 --- a/app/ldap_protocol/kerberos/client.py +++ b/app/ldap_protocol/kerberos/client.py @@ -20,12 +20,17 @@ async def add_principal( self, name: str, password: str | None, - timeout: int = 1, + algorithms: list[str] | None = None, + timeout: int | float = 1, ) -> None: """Add request.""" response = await self.client.post( "principal", - json={"name": name, "password": password}, + json={ + "principal_name": name, + "password": password, + "algorithms": algorithms, + }, timeout=timeout, ) @@ -89,17 +94,32 @@ async def create_or_update_principal_pw( raise krb_exc.KRBAPIChangePasswordError(response.text) @logger_wraps() - async def rename_princ(self, name: str, new_name: str) -> None: + async def modify_princ( + self, + name: 
str, + new_name: str | None, + algorithms: list[str] | None, + password: str | None, + ) -> None: """Rename request.""" response = await self.client.put( "principal", - json={"name": name, "new_name": new_name}, + json={ + "name": name, + "new_name": new_name, + "algorithms": algorithms, + "password": password, + }, ) if response.status_code != 202: - raise krb_exc.KRBAPIRenamePrincipalError(response.text) + raise krb_exc.KRBAPIModifyPrincipalError(response.text) @logger_wraps() - async def ktadd(self, names: list[str]) -> httpx.Response: + async def ktadd( + self, + names: list[str], + is_rand_key: bool, + ) -> httpx.Response: """Ktadd build request for stream and return response. :param list[str] names: principals @@ -108,7 +128,7 @@ async def ktadd(self, names: list[str]) -> httpx.Response: request = self.client.build_request( "POST", "/principal/ktadd", - json=names, + json={"names": names, "is_rand_key": is_rand_key}, ) response = await self.client.send(request, stream=True) diff --git a/app/ldap_protocol/kerberos/schemas.py b/app/ldap_protocol/kerberos/dtos.py similarity index 85% rename from app/ldap_protocol/kerberos/schemas.py rename to app/ldap_protocol/kerberos/dtos.py index b3be3abb5..d01775aee 100644 --- a/app/ldap_protocol/kerberos/schemas.py +++ b/app/ldap_protocol/kerberos/dtos.py @@ -11,7 +11,7 @@ @dataclass -class KerberosAdminDnGroup: +class KerberosAdminDnGroupDTO: """Kerberos admin, services container, and admin group DNs.""" krbadmin_dn: str @@ -20,8 +20,8 @@ class KerberosAdminDnGroup: @dataclass -class AddRequests: - """AddRequests for Kerberos admin structure: group, services, krb_user.""" +class AddRequestsDTO: + """AddRequestsDTO for Kerberos admin structure.""" group: AddRequest services: AddRequest @@ -29,7 +29,7 @@ class AddRequests: @dataclass -class KDCContext: +class KDCContextDTO: """Kerberos KDC configuration context.""" base_dn: str @@ -43,7 +43,7 @@ class KDCContext: @dataclass -class TaskStruct: +class TaskStructDTO: 
"""Structure for background task: function, args, kwargs.""" func: Callable[..., Any] diff --git a/app/ldap_protocol/kerberos/exceptions.py b/app/ldap_protocol/kerberos/exceptions.py index 735149eff..2008aff1c 100644 --- a/app/ldap_protocol/kerberos/exceptions.py +++ b/app/ldap_protocol/kerberos/exceptions.py @@ -31,7 +31,7 @@ class ErrorCodes(IntEnum): KERBEROS_API_GET_PRINCIPAL_ERROR = 16 KERBEROS_API_DELETE_PRINCIPAL_ERROR = 17 KERBEROS_API_CHANGE_PASSWORD_ERROR = 18 - KERBEROS_API_RENAME_PRINCIPAL_ERROR = 19 + KERBEROS_API_MODIFY_PRINCIPAL_ERROR = 19 KERBEROS_API_LOCK_PRINCIPAL_ERROR = 20 KERBEROS_API_FORCE_PASSWORD_CHANGE_ERROR = 21 KERBEROS_API_STATUS_NOT_FOUND_ERROR = 22 @@ -132,10 +132,10 @@ class KRBAPIChangePasswordError(KRBAPIError): code = ErrorCodes.KERBEROS_API_CHANGE_PASSWORD_ERROR -class KRBAPIRenamePrincipalError(KRBAPIError): +class KRBAPIModifyPrincipalError(KRBAPIError): """Rename principal error.""" - code = ErrorCodes.KERBEROS_API_RENAME_PRINCIPAL_ERROR + code = ErrorCodes.KERBEROS_API_MODIFY_PRINCIPAL_ERROR class KRBAPILockPrincipalError(KRBAPIError): diff --git a/app/ldap_protocol/kerberos/ldap_structure.py b/app/ldap_protocol/kerberos/ldap_structure.py index 45228a3c8..fec8741c0 100644 --- a/app/ldap_protocol/kerberos/ldap_structure.py +++ b/app/ldap_protocol/kerberos/ldap_structure.py @@ -68,6 +68,7 @@ async def create_kerberos_structure( async with self._session.begin_nested(): await self._role_use_case.create_kerberos_system_role() + await self._role_use_case.add_read_only_role_to_krbadmin_group() user_result = await anext(krb_user.handle(ctx)) if user_result.result_code != 0: raise KerberosConflictError("User error") diff --git a/app/ldap_protocol/kerberos/service.py b/app/ldap_protocol/kerberos/service.py index f6a0aae05..fa838abb9 100644 --- a/app/ldap_protocol/kerberos/service.py +++ b/app/ldap_protocol/kerberos/service.py @@ -30,19 +30,24 @@ from password_utils import PasswordUtils from .base import AbstractKadmin +from .dtos import 
( + AddRequestsDTO, + KDCContextDTO, + KerberosAdminDnGroupDTO, + TaskStructDTO, +) from .exceptions import ( KRBAPIAddPrincipalError, KRBAPIConnectionError, KRBAPIDeletePrincipalError, + KRBAPIModifyPrincipalError, KRBAPIPrincipalNotFoundError, - KRBAPIRenamePrincipalError, KRBAPISetupConfigsError, KRBAPISetupStashError, KRBAPISetupTreeError, KRBAPIStatusNotFoundError, ) from .ldap_structure import KRBLDAPStructureManager -from .schemas import AddRequests, KDCContext, KerberosAdminDnGroup, TaskStruct from .template_render import KRBTemplateRenderer from .utils import ( KerberosState, @@ -138,17 +143,20 @@ async def _get_base_dn(self) -> tuple[str, str]: ) return base_dn_list[0].path_dn, base_dn_list[0].name - def _build_kerberos_admin_dns(self, base_dn: str) -> KerberosAdminDnGroup: + def _build_kerberos_admin_dns( + self, + base_dn: str, + ) -> KerberosAdminDnGroupDTO: """Build DN strings for Kerberos admin, services, and group. :param str base_dn: Base DN. - :return KerberosAdminDnGroup: + :return KerberosAdminDnGroupDTO: dataclass with DN for krbadmin, services_container, krbadmin_group. """ - krbadmin = f"cn=krbadmin,cn=users,{base_dn}" + krbadmin = f"cn=krbadmin,cn=Users,{base_dn}" services_container = get_system_container_dn(base_dn) - krbgroup = f"cn=krbadmin,cn=groups,{base_dn}" - return KerberosAdminDnGroup( + krbgroup = f"cn=krbadmin,cn=Groups,{base_dn}" + return KerberosAdminDnGroupDTO( krbadmin_dn=krbadmin, services_container_dn=services_container, krbadmin_group_dn=krbgroup, @@ -156,17 +164,17 @@ def _build_kerberos_admin_dns(self, base_dn: str) -> KerberosAdminDnGroup: def _build_add_requests( self, - dns: KerberosAdminDnGroup, + dns: KerberosAdminDnGroupDTO, mail: str, krbadmin_password: SecretStr, - ) -> AddRequests: + ) -> AddRequestsDTO: """Build AddRequest objects for group, services, and admin user. - :param KerberosAdminDnGroup dns: + :param KerberosAdminDnGroupDTO dns: DNs for krbadmin, services container, and group. 
:param str mail: Email for krbadmin. :param SecretStr krbadmin_password: Password for krbadmin. - :return AddRequests: + :return AddRequestsDTO: dataclass of AddRequest for group, services, and user. """ group = AddRequest.from_dict( @@ -219,7 +227,7 @@ def _build_add_requests( }, is_system=True, ) - return AddRequests( + return AddRequestsDTO( group=group, services=services, krb_user=krb_user, @@ -232,8 +240,8 @@ async def setup_kdc( stash_password: str, user: UserSchema, request: Request, - ) -> TaskStruct: - """Set up KDC, generate configs, and return TaskStruct. + ) -> TaskStructDTO: + """Set up KDC, generate configs, and return TaskStructDTO. Args: krbadmin_password (str): Password for krbadmin. @@ -289,17 +297,17 @@ async def setup_kdc( admin_password, ) - async def _get_kdc_context(self) -> KDCContext: + async def _get_kdc_context(self) -> KDCContextDTO: """Build and return context for KDC setup/config rendering. :raises Exception: If base DN cannot be retrieved. - :return KDCContext: dataclass with all required KDC context fields. + :return KDCContextDTO: dataclass with all required KDC context fields. """ base_dn, domain = await self._get_base_dn() krbadmin = f"cn=krbadmin,cn=users,{base_dn}" krbgroup = f"cn=krbadmin,cn=groups,{base_dn}" services_container = get_system_container_dn(base_dn) - return KDCContext( + return KDCContextDTO( base_dn=base_dn, domain=domain, krbadmin=krbadmin, @@ -335,7 +343,7 @@ async def _schedule_principal_task( request: Request, user: UserSchema, password: str, - ) -> TaskStruct: + ) -> TaskStructDTO: """Schedule background task for principal creation after KDC setup. :param Request request: FastAPI request (for DI container). 
@@ -356,9 +364,14 @@ async def _schedule_principal_task( user.user_principal_name.split("@")[0], password, ) - return TaskStruct(func=func, args=args) + return TaskStructDTO(func=func, args=args) - async def add_principal(self, primary: str, instance: str) -> None: + async def add_principal( + self, + principal_name: str, + password: str | None, + algorithms: list[str] | None, + ) -> None: """Create principal in Kerberos with given name. :param str primary: Principal primary name. @@ -367,52 +380,42 @@ async def add_principal(self, primary: str, instance: str) -> None: :return None: None. """ try: - principal_name = f"{primary}/{instance}" - await self._kadmin.add_principal(principal_name, None) + await self._kadmin.add_principal( + principal_name, + password, + algorithms, + ) except KRBAPIAddPrincipalError as exc: raise KerberosDependencyError( f"Error adding principal: {exc}", ) from exc - async def rename_principal( + async def modify_principal( self, principal_name: str, - principal_new_name: str, + new_name: str | None, + algorithms: list[str] | None, + password: str | None, ) -> None: - """Rename principal in Kerberos with given name. + """Modify principal in Kerberos with given name. :param str principal_name: Current principal name. - :param str principal_new_name: New principal name. + :param str new_name: New principal name. + :param list[str] | None algorithms: Algorithms. + :param str | None password: Password. :raises KerberosDependencyError: On failed kadmin request. :return None: None. """ try: - await self._kadmin.rename_princ(principal_name, principal_new_name) - except KRBAPIRenamePrincipalError as exc: - raise KerberosDependencyError( - f"Error renaming principal: {exc}", - ) from exc - - async def reset_principal_pw( - self, - principal_name: str, - new_password: str, - ) -> None: - """Reset principal password in Kerberos with given name. - - :param str principal_name: Principal name. - :param str new_password: New password. 
- :raises KerberosDependencyError: On failed kadmin request. - :return None: None. - """ - try: - await self._kadmin.change_principal_password( + await self._kadmin.modify_princ( principal_name, - new_password, + new_name, + algorithms, + password, ) - except Exception as exc: + except KRBAPIModifyPrincipalError as exc: raise KerberosDependencyError( - f"Error resetting principal password: {exc}", + f"Error renaming principal: {exc}", ) from exc async def delete_principal(self, principal_name: str) -> None: @@ -432,20 +435,22 @@ async def delete_principal(self, principal_name: str) -> None: async def ktadd( self, names: list[str], - ) -> tuple[AsyncIterator[bytes], TaskStruct]: - """Generate keytab and return (aiter_bytes, TaskStruct). + is_rand_key: bool, + ) -> tuple[AsyncIterator[bytes], TaskStructDTO]: + """Generate keytab and return (aiter_bytes, TaskStructDTO). :param list[str] names: List of principal names. + :param bool is_rand_key: If True, generate new principal keys. :raises KerberosNotFoundError: If principal not found. :return tuple: (aiter_bytes, (func, args, kwargs)). """ try: - response = await self._kadmin.ktadd(names) + response = await self._kadmin.ktadd(names, is_rand_key) except KRBAPIPrincipalNotFoundError: raise KerberosNotFoundError("Principal not found") aiter_bytes = response.aiter_bytes() func = response.aclose - return aiter_bytes, TaskStruct(func=func) + return aiter_bytes, TaskStructDTO(func=func) async def get_status(self) -> KerberosState: """Get Kerberos server state (db + actual server). 
@@ -469,7 +474,6 @@ async def get_status(self) -> KerberosState: ktadd.__name__: AuthorizationRules.KRB_KTADD, get_status.__name__: AuthorizationRules.KRB_GET_STATUS, add_principal.__name__: AuthorizationRules.KRB_ADD_PRINCIPAL, - rename_principal.__name__: AuthorizationRules.KRB_RENAME_PRINCIPAL, - reset_principal_pw.__name__: AuthorizationRules.KRB_RESET_PRINCIPAL_PW, + modify_principal.__name__: AuthorizationRules.KRB_MODIFY_PRINCIPAL, delete_principal.__name__: AuthorizationRules.KRB_DELETE_PRINCIPAL, } diff --git a/app/ldap_protocol/kerberos/stub.py b/app/ldap_protocol/kerberos/stub.py index 889583c16..5e50efdcf 100644 --- a/app/ldap_protocol/kerberos/stub.py +++ b/app/ldap_protocol/kerberos/stub.py @@ -19,8 +19,9 @@ async def setup(self, *args, **kwargs) -> None: # type: ignore @logger_wraps(is_stub=True) async def add_principal( self, - name: str, - password: str | None, + principal_name: str, + password: str | None = None, + algorithms: list[str] | None = None, timeout: int = 1, ) -> None: ... @@ -45,10 +46,16 @@ async def create_or_update_principal_pw( ) -> None: ... @logger_wraps(is_stub=True) - async def rename_princ(self, name: str, new_name: str) -> None: ... + async def modify_princ( + self, + name: str, + new_name: str | None, + algorithms: list[str] | None = None, + password: str | None = None, + ) -> None: ... 
@logger_wraps(is_stub=True) - async def ktadd(self, names: list[str]) -> NoReturn: # noqa: ARG002 + async def ktadd(self, names: list[str], is_rand_key: bool) -> NoReturn: # noqa: ARG002 raise KRBAPIPrincipalNotFoundError @logger_wraps(is_stub=True) diff --git a/app/ldap_protocol/kerberos/template_render.py b/app/ldap_protocol/kerberos/template_render.py index 0df7b36a7..d4428a596 100644 --- a/app/ldap_protocol/kerberos/template_render.py +++ b/app/ldap_protocol/kerberos/template_render.py @@ -6,7 +6,7 @@ import jinja2 -from .schemas import KDCContext +from .dtos import KDCContextDTO class KRBTemplateRenderer: @@ -23,11 +23,11 @@ def __init__(self, templates: jinja2.Environment) -> None: """ self._templates = templates - async def render_krb5(self, context: KDCContext) -> str: + async def render_krb5(self, context: KDCContextDTO) -> str: """Render the krb5.conf configuration file using the provided context. :param context: - KDCContext dataclass with Kerberos configuration parameters. + KDCContextDTO dataclass with Kerberos configuration parameters. :return: Rendered krb5.conf as a string. """ krb5_template = self._templates.get_template("krb5.conf") @@ -40,11 +40,11 @@ async def render_krb5(self, context: KDCContext) -> str: sync_password_url=context.sync_password_url, ) - async def render_kdc(self, context: KDCContext) -> str: + async def render_kdc(self, context: KDCContextDTO) -> str: """Render the kdc.conf configuration file using the provided context. :param context: - KDCContext dataclass with Kerberos configuration parameters. + KDCContextDTO dataclass with Kerberos configuration parameters. :return: Rendered kdc.conf as a string. 
""" kdc_template = self._templates.get_template("kdc.conf") diff --git a/app/ldap_protocol/kerberos/utils.py b/app/ldap_protocol/kerberos/utils.py index c6278ed95..026b71a1d 100644 --- a/app/ldap_protocol/kerberos/utils.py +++ b/app/ldap_protocol/kerberos/utils.py @@ -63,8 +63,7 @@ async def wrapped(*args: str, **kwargs: str) -> Any: except Exception as err: if isinstance(err, KRBAPIError): logger.error(f"{name} call raised: {err}") - raise - + raise else: if not is_stub: logger.success(f"Executed {name}") diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 75be3f6fc..d6e6e8078 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -13,7 +13,7 @@ from constants import DOMAIN_COMPUTERS_GROUP_NAME, DOMAIN_USERS_GROUP_NAME from entities import Attribute, Directory, Group, User -from enums import AceType, EntityTypeNames +from enums import AceType, EntityTypeNames, SamAccountTypeCodes from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.kerberos.exceptions import ( KRBAPIAddPrincipalError, @@ -64,6 +64,7 @@ class AddRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = AddResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext @@ -233,7 +234,12 @@ async def handle( # noqa: C901 parent_groups: list[Group] = [] user_attributes: dict[str, str] = {} group_attributes: list[str] = [] - user_fields = User.search_fields.keys() | User.fields.keys() + is_user_like = "user" in self.object_class_names + user_fields = ( + User.search_fields.keys() | User.fields.keys() + if is_user_like + else set() + ) attributes.append( Attribute( @@ -249,11 +255,7 @@ async def handle( # noqa: C901 # in the attributes if ( attr_name in Directory.ro_fields - or attr_name - in ( - "userpassword", - "unicodepwd", - ) + or attr_name in ("userpassword", "unicodepwd") or attr_name == new_dir.rdname ): continue @@ -294,6 +296,7 @@ async 
def handle( # noqa: C901 or "userPrincipalName" in user_attributes ) is_computer = "computer" in self.attrs_dict.get("objectClass", []) + computer_sam_account_name = None if is_user: if not any( @@ -342,11 +345,11 @@ async def handle( # noqa: C901 ), ) - for uattr, value in { - "loginShell": "/bin/bash", - "uidNumber": str(create_integer_hash(user.sam_account_name)), - "homeDirectory": f"/home/{user.sam_account_name}", - }.items(): + for uattr, value in ( + ("loginShell", "/bin/bash"), + ("uidNumber", str(create_integer_hash(user.sam_account_name))), + ("homeDirectory", f"/home/{user.sam_account_name}"), + ): if uattr in user_attributes: value = user_attributes[uattr] del user_attributes[uattr] @@ -372,33 +375,44 @@ async def handle( # noqa: C901 items_to_add.append(group) group.parent_groups.extend(parent_groups) - elif is_computer and "useraccountcontrol" not in self.l_attrs_dict: - if not any( - group.directory.name.lower() == DOMAIN_COMPUTERS_GROUP_NAME - for group in parent_groups - ): - parent_groups.append( - await get_group( - DOMAIN_COMPUTERS_GROUP_NAME, - ctx.session, - ), - ) - await ctx.session.refresh( - instance=new_dir, - attribute_names=["groups"], - with_for_update=None, - ) - new_dir.groups.extend(parent_groups) + elif is_computer: + computer_sam_account_name = new_dir.name + attributes.append( Attribute( - name="userAccountControl", - value=str( - UserAccountControlFlag.WORKSTATION_TRUST_ACCOUNT, - ), + name="sAMAccountName", + value=computer_sam_account_name, directory_id=new_dir.id, ), ) + if "useraccountcontrol" not in self.l_attrs_dict: + if not any( + group.directory.name.lower() == DOMAIN_COMPUTERS_GROUP_NAME + for group in parent_groups + ): + parent_groups.append( + await get_group( + DOMAIN_COMPUTERS_GROUP_NAME, + ctx.session, + ), + ) + await ctx.session.refresh( + instance=new_dir, + attribute_names=["groups"], + with_for_update=None, + ) + new_dir.groups.extend(parent_groups) + attributes.append( + Attribute( + 
name="userAccountControl", + value=str( + UserAccountControlFlag.WORKSTATION_TRUST_ACCOUNT, + ), + directory_id=new_dir.id, + ), + ) + if (is_user or is_group) and "gidnumber" not in self.l_attrs_dict: reverse_d_name = new_dir.name[::-1] value = ( @@ -421,6 +435,32 @@ async def handle( # noqa: C901 ), ) + if "samaccounttype" not in self.l_attrs_dict: + if is_user: + attributes.append( + Attribute( + name="sAMAccountType", + value=str(SamAccountTypeCodes.SAM_USER_OBJECT), + directory_id=new_dir.id, + ), + ) + elif is_group: + attributes.append( + Attribute( + name="sAMAccountType", + value=str(SamAccountTypeCodes.SAM_GROUP_OBJECT), + directory_id=new_dir.id, + ), + ) + elif is_computer: + attributes.append( + Attribute( + name="sAMAccountType", + value=str(SamAccountTypeCodes.SAM_MACHINE_ACCOUNT), + directory_id=new_dir.id, + ), + ) + if not ctx.attribute_value_validator.is_directory_attributes_valid( entity_type.name if entity_type else "", attributes, @@ -461,24 +501,22 @@ async def handle( # noqa: C901 KRBAPIDeletePrincipalError, KRBAPIPrincipalNotFoundError, ): - await ctx.kadmin.del_principal( - user.get_upn_prefix(), - ) + await ctx.kadmin.del_principal(user.sam_account_name) pw = ( self.password.get_secret_value() if self.password else None ) - await ctx.kadmin.add_principal(user.get_upn_prefix(), pw) + await ctx.kadmin.add_principal(user.sam_account_name, pw) elif is_computer: await ctx.kadmin.add_principal( - f"{new_dir.host_principal}.{base_dn.name}", + f"host/{computer_sam_account_name}.{base_dn.name}", None, ) await ctx.kadmin.add_principal( - new_dir.host_principal, + f"host/{computer_sam_account_name}", None, ) except (KRBAPIAddPrincipalError, KRBAPIConnectionError): diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index 445ce3bae..63667f034 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -18,11 +18,13 @@ from dishka import AsyncContainer from loguru 
import logger from pydantic import BaseModel +from sqlalchemy.exc import OperationalError from config import Settings from entities import Directory from ldap_protocol.dependency import resolve_deps from ldap_protocol.dialogue import LDAPSession +from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.ldap_responses import BaseResponse, LDAPResult from ldap_protocol.objects import ProtocolRequests from ldap_protocol.policies.audit.audit_use_case import AuditUseCase @@ -63,6 +65,7 @@ class _APIProtocol: ... class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" + RESPONSE_TYPE: ClassVar[type] CONTEXT_TYPE: ClassVar[type] handle: ClassVar[handler] from_data: ClassVar[serializer] @@ -118,9 +121,17 @@ async def handle_tcp( ctx = await container.get(self.CONTEXT_TYPE) # type: ignore responses = [] - async for response in self.handle(ctx=ctx): - responses.append(response) - yield response + try: + async for response in self.handle(ctx=ctx): + responses.append(response) + yield response + except OperationalError: + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + yield self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ) + return if self.PROTOCOL_OP != ProtocolRequests.SEARCH: ldap_session = await container.get(LDAPSession) @@ -172,7 +183,17 @@ async def _handle_api( else: log_api.info(f"{get_class_name(self)}[{un}]") - responses = [response async for response in self.handle(ctx=ctx)] + try: + responses = [response async for response in self.handle(ctx=ctx)] + except OperationalError: + responses = [] + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + responses.append( + self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ), + ) if settings.DEBUG: for response in responses: diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index 445b2f25c..a747c26f6 100644 --- 
a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, ClassVar from pydantic import Field +from sqlalchemy.exc import OperationalError from entities import NetworkPolicy from enums import MFAFlags @@ -42,6 +43,7 @@ class BindRequest(BaseRequest): """Bind request fields mapping.""" + RESPONSE_TYPE: ClassVar[type] = BindResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.BIND CONTEXT_TYPE: ClassVar[type] = LDAPBindRequestContext @@ -209,13 +211,18 @@ async def handle( KRBAPIConnectionError, ): await ctx.kadmin.add_principal( - user.get_upn_prefix(), + user.sam_account_name, self.authentication_choice.password.get_secret_value(), - 0.1, + timeout=0.1, ) await ctx.ldap_session.set_user(user) - await set_user_logon_attrs(user, ctx.session, ctx.settings.TIMEZONE) + with contextlib.suppress(OperationalError): + await set_user_logon_attrs( + user, + ctx.session, + ctx.settings.TIMEZONE, + ) server_sasl_creds = None if isinstance(self.authentication_choice, SaslSPNEGOAuthentication): diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index e2b127331..b8ad639d8 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -42,6 +42,7 @@ class DeleteRequest(BaseRequest): DelRequest ::= [APPLICATION 10] LDAPDN """ + RESPONSE_TYPE: ClassVar[type] = DeleteResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.DELETE CONTEXT_TYPE: ClassVar[type] = LDAPDeleteRequestContext @@ -74,6 +75,7 @@ async def handle( # noqa: C901 select(Directory) .options( joinedload(qa(Directory.user)), + joinedload(qa(Directory.entity_type)), selectinload(qa(Directory.groups)).selectinload( qa(Group.directory), ), @@ -154,13 +156,16 @@ async def handle( # noqa: C901 await ctx.session_storage.clear_user_sessions( directory.user.id, ) - await ctx.kadmin.del_principal(directory.user.get_upn_prefix()) + await 
ctx.kadmin.del_principal(directory.user.sam_account_name) if await is_computer(directory.id, ctx.session): - await ctx.kadmin.del_principal(directory.host_principal) - await ctx.kadmin.del_principal( - f"{directory.host_principal}.{base_dn.name}", - ) + computer_sam_account_names = directory.attributes_dict.get("sAMAccountName") # noqa: E501 # fmt: skip + if computer_sam_account_names: + computer_sam_account_name = computer_sam_account_names[0] + await ctx.kadmin.del_principal(f"host/{computer_sam_account_name}") # noqa: E501 # fmt: skip + await ctx.kadmin.del_principal(f"host/{computer_sam_account_name}.{base_dn.name}") # noqa: E501 # fmt: skip + else: + raise KRBAPIDeletePrincipalError except KRBAPIPrincipalNotFoundError: pass except (KRBAPIDeletePrincipalError, KRBAPIConnectionError): diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index c3967889e..a3e74ad28 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -248,7 +248,7 @@ async def handle( ): try: await ctx.kadmin.create_or_update_principal_pw( - user.get_upn_prefix(), + user.sam_account_name, new_password, ) except ( @@ -307,6 +307,7 @@ class ExtendedRequest(BaseRequest): requestValue [1] OCTET STRING OPTIONAL } """ + RESPONSE_TYPE: ClassVar[type] = ExtendedResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.EXTENDED CONTEXT_TYPE: ClassVar[type] = LDAPExtendedRequestContext request_name: LDAPOID diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 676550e3e..9b1b03edf 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -8,7 +8,8 @@ from typing import AsyncGenerator, ClassVar from loguru import logger -from sqlalchemy import Select, and_, delete, or_, select, update +from pydantic import PrivateAttr +from sqlalchemy import Select, and_, delete, func, or_, select, update from sqlalchemy.exc 
import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload @@ -24,6 +25,7 @@ KRBAPIConnectionError, KRBAPIForcePasswordChangeError, KRBAPILockPrincipalError, + KRBAPIModifyPrincipalError, KRBAPIPrincipalNotFoundError, ) from ldap_protocol.ldap_codes import LDAPCodes @@ -37,11 +39,16 @@ from ldap_protocol.policies.password import PasswordPolicyUseCases from ldap_protocol.session_storage import SessionStorage from ldap_protocol.utils.cte import check_root_group_membership_intersection -from ldap_protocol.utils.helpers import ft_to_dt, validate_entry +from ldap_protocol.utils.helpers import ( + ft_to_dt, + is_dn_in_base_directory, + validate_entry, +) from ldap_protocol.utils.queries import ( add_lock_and_expire_attributes, clear_group_membership, extend_group_membership, + get_base_directories, get_directories, get_directory_by_rid, get_filter_from_path, @@ -71,6 +78,7 @@ class ModifyForbiddenError(Exception): PermissionError, ModifyForbiddenError, KRBAPIPrincipalNotFoundError, + KRBAPIModifyPrincipalError, KRBAPILockPrincipalError, KRBAPIForcePasswordChangeError, ) @@ -94,12 +102,18 @@ class ModifyRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = ModifyResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY CONTEXT_TYPE: ClassVar[type] = LDAPModifyRequestContext object: str changes: list[Changes] + # NOTE: If the old value was changed (for example, in _delete) + # in one method, then you need to have access to the old value + # from other methods (for example, from _add) + _old_vals: dict[str, str | None] = PrivateAttr(default_factory=dict) + @classmethod def from_data(cls, data: list[ASN1Row]) -> "ModifyRequest": entry, proto_changes = data @@ -131,7 +145,7 @@ async def _update_password_expiration( return if not ( - change.modification.type == "krbpasswordexpiration" + change.l_type == "krbpasswordexpiration" and change.modification.vals[0] == "19700101000000Z" ): return @@ -184,7 
+198,7 @@ async def handle( entity_type_id=directory.entity_type_id, ) - names = {change.get_name() for change in self.changes} + names = {change.l_type for change in self.changes} password_change_requested = self._is_password_change_requested(names) self_modify = directory.id == ctx.ldap_session.user.directory_id @@ -194,25 +208,23 @@ async def handle( and await ctx.password_use_cases.is_password_change_restricted( directory.id, ) + ) or ( + not can_modify and not (password_change_requested and self_modify) ): yield ModifyResponse( result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, ) return + if directory.rdname in names: + yield ModifyResponse(result_code=LDAPCodes.NOT_ALLOWED_ON_RDN) + return + before_attrs = self.get_directory_attrs(directory) entity_type = directory.entity_type try: - if not can_modify and not ( - password_change_requested and self_modify - ): - yield ModifyResponse( - result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, - ) - return - for change in self.changes: - if change.modification.type.lower() in Directory.ro_fields: + if change.l_type in Directory.ro_fields: continue if not ctx.attribute_value_validator.is_partial_attribute_valid( # noqa: E501 @@ -222,7 +234,7 @@ async def handle( await ctx.session.rollback() yield ModifyResponse( result_code=LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE, - message="Invalid attribute value(s)", + error_message="Invalid attribute value(s)", ) return @@ -272,10 +284,10 @@ async def handle( except MODIFY_EXCEPTION_STACK as err: await ctx.session.rollback() - result_code, message = self._match_bad_response(err) + result_code, error_message = self._match_bad_response(err) yield ModifyResponse( result_code=result_code, - message=message, + error_message=error_message, ) return @@ -321,6 +333,9 @@ def _match_bad_response(self, err: BaseException) -> tuple[LDAPCodes, str]: case ModifyForbiddenError(): return LDAPCodes.OPERATIONS_ERROR, str(err) + case KRBAPIModifyPrincipalError(): + return LDAPCodes.UNAVAILABLE, "Kerberos 
error" + case KRBAPIPrincipalNotFoundError(): return LDAPCodes.UNAVAILABLE, "Kerberos error" @@ -612,6 +627,18 @@ async def _validate_object_class_modification( if is_object_class_in_replaced or is_object_class_in_deleted: raise ModifyForbiddenError("ObjectClass can't be deleted.") + def _need_to_cache_samaccountname_old_value( + self, + change: Changes, + directory: Directory, + ) -> bool: + return bool( + directory.entity_type + and directory.entity_type.name == EntityTypeNames.COMPUTER + and change.l_type == "samaccountname" + and not self._old_vals.get(change.modification.type), + ) + async def _delete( self, change: Changes, @@ -621,9 +648,8 @@ async def _delete( name_only: bool = False, ) -> None: attrs = [] - name = change.modification.type.lower() - if name == "memberof": + if change.l_type == "memberof": await self._delete_memberof( change=change, directory=directory, @@ -632,7 +658,7 @@ async def _delete( ) return - if name == "member": + if change.l_type == "member": await self._delete_member( change=change, directory=directory, @@ -641,14 +667,16 @@ async def _delete( ) return - if name == "objectclass": + if change.l_type == "objectclass": await self._validate_object_class_modification(change, directory) if name_only or not change.modification.vals: attrs.append(qa(Attribute.name) == change.modification.type) else: for value in change.modification.vals: - if name not in (Directory.search_fields | User.search_fields): + if change.l_type not in ( + Directory.search_fields | User.search_fields + ): if isinstance(value, str): condition = qa(Attribute.value) == value elif isinstance(value, bytes): @@ -656,10 +684,15 @@ async def _delete( attrs.append( and_( - qa(Attribute.name) == change.modification.type, + func.lower(qa(Attribute.name)) == change.l_type, condition, ), - ) + ) # fmt: skip + + if self._need_to_cache_samaccountname_old_value(change, directory): + vals = directory.attributes_dict.get(change.modification.type) + if vals: + 
self._old_vals[change.modification.type] = vals[0] if attrs: del_query = ( @@ -773,16 +806,15 @@ async def _add_group_attrs( directory: Directory, session: AsyncSession, ) -> None: - name = change.get_name() - if name == "primarygroupid": + if change.l_type == "primarygroupid": await self._add_primary_group_attribute( change, directory, session, ) - elif name == "memberof": + elif change.l_type == "memberof": await self._add_memberof(change, directory, session) - elif name == "member": + elif change.l_type == "member": await self._add_member(change, directory, session) async def _add( # noqa: C901 @@ -797,24 +829,22 @@ async def _add( # noqa: C901 password_use_cases: PasswordPolicyUseCases, password_utils: PasswordUtils, ) -> None: + base_dir = None attrs = [] - name = change.get_name() - if name in {"memberof", "member", "primarygroupid"}: + if change.l_type in ("memberof", "member", "primarygroupid"): await self._add_group_attrs(change, directory, session) return for value in change.modification.vals: - if name == "useraccountcontrol": + if change.l_type == "useraccountcontrol": uac_val = int(value) if not UserAccountControlFlag.is_value_valid(uac_val): continue elif ( - bool( - uac_val & UserAccountControlFlag.ACCOUNTDISABLE, - ) + bool(uac_val & UserAccountControlFlag.ACCOUNTDISABLE) and directory.user ): if directory.path_dn == current_user.dn: @@ -823,7 +853,7 @@ async def _add( # noqa: C901 ) await kadmin.lock_principal( - directory.user.get_upn_prefix(), + directory.user.sam_account_name, ) await add_lock_and_expire_attributes( @@ -837,9 +867,7 @@ async def _add( # noqa: C901 ) elif ( - not bool( - uac_val & UserAccountControlFlag.ACCOUNTDISABLE, - ) + not bool(uac_val & UserAccountControlFlag.ACCOUNTDISABLE) and directory.user ): await unlock_principal( @@ -858,37 +886,100 @@ async def _add( # noqa: C901 ), ) # fmt: skip - if name == "pwdlastset" and value == "0" and directory.user: + if ( + change.l_type == "pwdlastset" + and value == "0" + and 
directory.user + ): await kadmin.force_princ_pw_change( - directory.user.get_upn_prefix(), + directory.user.sam_account_name, ) - if name == directory.rdname: + if change.l_type == directory.rdname: await session.execute( update(Directory) .filter(directory_table.c.id == directory.id) .values(name=value), ) - if name in Directory.search_fields: + if change.l_type in Directory.search_fields: await session.execute( update(Directory) .filter(directory_table.c.id == directory.id) - .values({name: value}), + .values({change.l_type: value}), ) - elif name in User.search_fields: - if name == "accountexpires": + elif ( + change.l_type in User.search_fields + and directory.entity_type + and directory.entity_type.name == EntityTypeNames.USER + and directory.user + ): + if change.l_type == "accountexpires": new_value = ft_to_dt(int(value)) if value != "0" else None else: new_value = value # type: ignore - await session.execute( - update(User) - .filter_by(directory=directory) - .values({name: new_value}), + if change.l_type in ("userprincipalname", "samaccountname"): + if change.l_type == "userprincipalname": + new_user_principal_name = str(new_value) + new_sam_account_name = new_user_principal_name.split("@")[0] # noqa: E501 # fmt: skip + elif change.l_type == "samaccountname": + if not base_dir: + base_dir = await self._get_base_dir( + directory, + session, + ) + + new_sam_account_name = str(new_value) + new_user_principal_name = f"{new_sam_account_name}@{base_dir.name}" # noqa: E501 # fmt: skip + + if directory.user.sam_account_name != new_sam_account_name: + await kadmin.modify_princ( + directory.user.sam_account_name, + new_sam_account_name, + ) + + directory.user.user_principal_name = new_user_principal_name # noqa: E501 # fmt: skip + directory.user.sam_account_name = new_sam_account_name + else: + await session.execute( + update(User) + .filter_by(directory=directory) + .values({change.l_type: new_value}), + ) + + elif ( + change.l_type == "samaccountname" + and 
directory.entity_type + and directory.entity_type.name == EntityTypeNames.COMPUTER + ): + if not base_dir: + base_dir = await self._get_base_dir( + directory, + session, + ) + + await self._modify_computer_samaccountname( + change, + kadmin, + base_dir, + value, ) - elif name in ("userpassword", "unicodepwd") and directory.user: + + attrs.append( + Attribute( + name=change.modification.type, + value=value if isinstance(value, str) else None, + bvalue=value if isinstance(value, bytes) else None, + directory_id=directory.id, + ), + ) # fmt: skip + + elif ( + change.l_type in ("userpassword", "unicodepwd") + and directory.user + ): if not settings.USE_CORE_TLS: raise PermissionError("TLS required") @@ -918,7 +1009,7 @@ async def _add( # noqa: C901 directory.user, ) await kadmin.create_or_update_principal_pw( - directory.user.get_upn_prefix(), + directory.user.sam_account_name, value, ) @@ -935,3 +1026,45 @@ async def _add( # noqa: C901 ) session.add_all(attrs) + + async def _modify_computer_samaccountname( + self, + change: Changes, + kadmin: AbstractKadmin, + base_dir: Directory, + new_sam_account_name: bytes | str, + ) -> None: + old_sam_account_name = self._old_vals.get(change.modification.type) + new_sam_account_name = str(new_sam_account_name) + + if not old_sam_account_name: + raise ModifyForbiddenError("Old sAMAccountName value not found.") + + if old_sam_account_name != new_sam_account_name: + await kadmin.modify_princ( + f"host/{old_sam_account_name}", + f"host/{new_sam_account_name}", + ) + await kadmin.modify_princ( + f"host/{old_sam_account_name}.{base_dir.name}", + f"host/{new_sam_account_name}.{base_dir.name}", + ) + + async def _get_base_dir( + self, + directory: Directory, + session: AsyncSession, + ) -> Directory: + base_dir = None + + for base_directory in await get_base_directories(session): + if is_dn_in_base_directory( + base_directory, + directory.path_dn, + ): + base_dir = base_directory + break + else: + raise ModifyForbiddenError("Base 
directory not found.") + + return base_dir diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 7c315eadd..7cd6d45c7 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -8,7 +8,8 @@ from sqlalchemy import delete, func, select, text, update from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from entities import AccessControlEntry, Attribute, Directory from enums import AceType @@ -18,7 +19,13 @@ INVALID_ACCESS_RESPONSE, ModifyDNResponse, ) -from ldap_protocol.objects import ProtocolRequests +from ldap_protocol.objects import ( + Changes, + Operation, + PartialAttribute, + ProtocolRequests, +) +from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.utils.queries import get_filter_from_path, validate_entry from repo.pg.tables import ( ace_directory_memberships_table, @@ -62,11 +69,12 @@ class ModifyDNRequest(BaseRequest): entry='cn=main,dc=multifactor,dc=dev' newrdn='cn=main2' deleteoldrdn=true - new_superior='cn=users,dc=multifactor,dc=dev' + new_superior='cn=Users,dc=multifactor,dc=dev' - >>> cn = main2, cn = users, dc = multifactor, dc = dev + >>> cn = main2, cn = Users, dc = multifactor, dc = dev """ + RESPONSE_TYPE: ClassVar[type] = ModifyDNResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY_DN CONTEXT_TYPE: ClassVar[type] = LDAPModifyDNRequestContext @@ -85,7 +93,87 @@ def from_data(cls, data: list[ASN1Row]) -> "ModifyDNRequest": new_superior=None if len(data) < 4 else data[3].value, ) - async def handle( + def _is_move_to_new_superior(self, directory: Directory) -> bool: + return bool( + self.new_superior + and directory.parent + and self.new_superior != directory.parent.path_dn, + ) + + def _can_modify_rdn( + self, + access_manager: AccessManager, + directory: Directory, + old_dn: 
str, + old_name: str, + new_dn: str, + new_name: str, + ) -> bool: + change = [ + Changes( + operation=Operation.ADD, + modification=PartialAttribute(type=new_dn, vals=[new_name]), + ), + ] + if self.deleteoldrdn: + change.append( + Changes( + operation=Operation.DELETE, + modification=PartialAttribute( + type=old_dn, + vals=[old_name], + ), + ), + ) + return access_manager.check_modify_access( + changes=change, + aces=directory.access_control_entries, + entity_type_id=directory.entity_type_id, + ) + + async def _delete_old_inherited_aces( + self, + session: AsyncSession, + directory: Directory, + old_depth: int, + ) -> None: + old_inherited_aces_ids = select(qa(AccessControlEntry.id)).where( + qa(AccessControlEntry.directories).contains(directory), + qa(AccessControlEntry.depth) != old_depth, + ) + await session.execute( + delete(ace_directory_memberships_table) + .filter_by( + directory_id=directory.id, + ) + .where( + ace_directory_memberships_table.c.access_control_entry_id.in_( + old_inherited_aces_ids, + ), + ), + ) + + async def _update_explicit_aces( + self, + session: AsyncSession, + directory: Directory, + old_depth: int, + new_path: list[str], + ) -> None: + new_path_dn = ",".join(reversed(new_path)) + new_depth = len(new_path) + + explicit_aces_ids = select(qa(AccessControlEntry.id)).where( + qa(AccessControlEntry.directories).contains(directory), + qa(AccessControlEntry.depth) == old_depth, + ) + await session.execute( + update(AccessControlEntry) + .where(qa(AccessControlEntry.id).in_(explicit_aces_ids)) + .values(path=new_path_dn, depth=new_depth), + ) + + async def handle( # noqa: C901 self, ctx: LDAPModifyDNRequestContext, ) -> AsyncGenerator[ModifyDNResponse, None]: @@ -122,8 +210,8 @@ async def handle( query = ctx.access_manager.mutate_query_with_ace_load( user_role_ids=ctx.ldap_session.user.role_ids, query=query, - ace_types=[AceType.DELETE], - require_attribute_type_null=True, + ace_types=[AceType.DELETE, AceType.WRITE], + 
load_attribute_type=True, ) directory = await ctx.session.scalar(query) @@ -142,15 +230,27 @@ async def handle( ) return - old_name = directory.name new_dn, new_name = self.newrdn.split("=") - directory.name = new_name + is_move_to_new_superior = self._is_move_to_new_superior(directory) + old_name = directory.name old_path = directory.path old_dn = old_path[-1].split("=")[0] - old_depth = directory.depth + if not self._can_modify_rdn( + access_manager=ctx.access_manager, + directory=directory, + old_dn=old_dn, + old_name=old_name, + new_dn=new_dn, + new_name=new_name, + ): + yield ModifyDNResponse( + result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, + ) + return + if ( directory.entity_type and not ctx.attribute_value_validator.is_value_valid( @@ -166,13 +266,31 @@ async def handle( ) return - if ( - self.new_superior - and directory.parent - and self.new_superior != directory.parent.path_dn - ): + directory.name = new_name + + if is_move_to_new_superior: + delete_aces = [ + ace + for ace in directory.access_control_entries + if ( + ace.ace_type == AceType.DELETE + and ace.attribute_type is None + ) + ] + + can_delete = ctx.access_manager.check_entity_level_access( + aces=delete_aces, + entity_type_id=directory.entity_type_id, + ) + + if not can_delete: + yield ModifyDNResponse( + result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, + ) + return + new_sup_query = select(Directory).filter( - get_filter_from_path(self.new_superior), + get_filter_from_path(self.new_superior), # type: ignore ) new_sup_query = ctx.access_manager.mutate_query_with_ace_load( user_role_ids=ctx.ldap_session.user.role_ids, @@ -203,11 +321,11 @@ async def handle( try: await ctx.session.flush() - await ctx.session.execute( - delete(ace_directory_memberships_table) - .filter_by(directory_id=directory.id), - ) # fmt: skip - + await self._delete_old_inherited_aces( + ctx.session, + directory=directory, + old_depth=old_depth, + ) await ctx.role_use_case.inherit_parent_aces( 
parent_directory=directory.parent, directory=directory, @@ -265,20 +383,12 @@ async def handle( ) await ctx.session.flush() - explicit_aces_query = ( - select(AccessControlEntry) - .options(selectinload(qa(AccessControlEntry.directories))) - .where( - qa(AccessControlEntry.directories).any( - qa(Directory.id) == directory.id, - ), - qa(AccessControlEntry.depth) == old_depth, - ) + await self._update_explicit_aces( + ctx.session, + directory, + old_depth, + new_path, ) - for ace in await ctx.session.scalars(explicit_aces_query): - ace.directories.append(directory) - ace.path = directory.path_dn - ace.depth = directory.depth await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index c6505322a..c9ab0bd57 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -104,6 +104,7 @@ class SearchRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = SearchResultDone PROTOCOL_OP: ClassVar[int] = ProtocolRequests.SEARCH CONTEXT_TYPE: ClassVar[type] = LDAPSearchRequestContext @@ -166,8 +167,8 @@ def all_attrs(self) -> bool: return "*" in self.requested_attrs or not self.requested_attrs @cached_property - def requested_attrs(self) -> list[str]: - return [attr.lower() for attr in self.attributes] + def requested_attrs(self) -> set[str]: + return {attr.lower() for attr in self.attributes} @classmethod def from_data(cls, data: dict[str, list[ASN1Row]]) -> "SearchRequest": @@ -252,7 +253,7 @@ def check_netlogon_filter(self) -> bool: return "netlogon" in self.requested_attrs async def _get_netlogon(self, ctx: LDAPSearchRequestContext) -> bytes: - rootdse = await ctx.rootdse_rd.get(self.requested_attrs) + rootdse = await ctx.rootdse_rd.get(set()) nl = NetLogonAttributeHandler.from_filter(rootdse, self.filter) return nl.get_attr() @@ -303,9 +304,16 @@ async def get_result( result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, ) return + base_directories = await 
get_base_directories(ctx.session) + if ( + ctx.settings.is_global_catalog + and not self.base_object + and base_directories + ): + self.base_object = base_directories[0].path_dn query = self._build_query( - await get_base_directories(ctx.session), + base_directories, user, ctx.access_manager, ) diff --git a/app/ldap_protocol/ldap_schema/attribute_type_dao.py b/app/ldap_protocol/ldap_schema/attribute_type_dao.py index 30211f74c..63b795e0a 100644 --- a/app/ldap_protocol/ldap_schema/attribute_type_dao.py +++ b/app/ldap_protocol/ldap_schema/attribute_type_dao.py @@ -56,9 +56,9 @@ def __init__(self, session: AsyncSession) -> None: """Initialize Attribute Type DAO with session.""" self.__session = session - async def get(self, _id: str) -> AttributeTypeDTO: - """Get Attribute Type by id.""" - return _convert_model_to_dto(await self._get_one_raw_by_name(_id)) + async def get(self, name: str) -> AttributeTypeDTO: + """Get Attribute Type by name.""" + return _convert_model_to_dto(await self._get_one_raw_by_name(name)) async def get_all(self) -> list[AttributeTypeDTO]: """Get all Attribute Types.""" @@ -82,7 +82,7 @@ async def create(self, dto: AttributeTypeDTO) -> None: + f" '{dto.name}' already exists.", ) - async def update(self, _id: str, dto: AttributeTypeDTO) -> None: + async def update(self, name: str, dto: AttributeTypeDTO) -> None: """Update Attribute Type. Docs: @@ -95,7 +95,7 @@ async def update(self, _id: str, dto: AttributeTypeDTO) -> None: can only be modified for non-system attributes to preserve LDAP schema integrity. 
""" - obj = await self._get_one_raw_by_name(_id) + obj = await self._get_one_raw_by_name(name) obj.is_included_anr = dto.is_included_anr @@ -106,9 +106,15 @@ async def update(self, _id: str, dto: AttributeTypeDTO) -> None: await self.__session.flush() - async def delete(self, _id: str) -> None: + async def update_sys_flags(self, name: str, dto: AttributeTypeDTO) -> None: + """Update system flags of Attribute Type.""" + obj = await self._get_one_raw_by_name(name) + obj.system_flags = dto.system_flags + await self.__session.flush() + + async def delete(self, name: str) -> None: """Delete Attribute Type.""" - attribute_type = await self._get_one_raw_by_name(_id) + attribute_type = await self._get_one_raw_by_name(name) await self.__session.delete(attribute_type) await self.__session.flush() @@ -150,7 +156,7 @@ async def _get_one_raw_by_name(self, name: str) -> AttributeType: async def get_all_by_names( self, names: list[str] | set[str], - ) -> list[AttributeTypeDTO]: + ) -> list[AttributeTypeDTO[int]]: """Get list of Attribute Types by names. :param list[str] names: Attribute Type names. diff --git a/app/ldap_protocol/ldap_schema/attribute_type_system_flags_use_case.py b/app/ldap_protocol/ldap_schema/attribute_type_system_flags_use_case.py new file mode 100644 index 000000000..a903028a8 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_type_system_flags_use_case.py @@ -0,0 +1,64 @@ +"""SystemFlags helpers for LDAP schema objects. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from __future__ import annotations + +from enum import IntFlag + +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO + + +class AttributeTypeSystemFlags(IntFlag): + """SystemFlags for attributeSchema objects in AD. + + Bits from 7 to 25 unused. Must be zero and ignored. 
+ ms-adts/1e38247d-8234-4273-9de3-bbf313548631 + """ + + ATTR_NOT_REPLICATED = 0x00000001 # 31 + ATTR_REQ_PARTIAL_SET_MEMBER = 0x00000002 # 30 + ATTR_IS_CONSTRUCTED = 0x00000004 # 29 + ATTR_IS_OPERATIONAL = 0x00000008 # 28 + SCHEMA_BASE_OBJECT = 0x00000010 # 27 + ATTR_IS_RDN = 0x00000020 # 26 + DISALLOW_MOVE_ON_DELETE = 0x02000000 # 6 + DOMAIN_DISALLOW_MOVE = 0x04000000 # 5 + DOMAIN_DISALLOW_RENAME = 0x08000000 # 4 + CONFIG_ALLOW_LIMITED_MOVE = 0x10000000 # 3 + CONFIG_ALLOW_MOVE = 0x20000000 # 2 + CONFIG_ALLOW_RENAME = 0x40000000 # 1 + DISALLOW_DELETE = 0x80000000 # 0 + + +class AttributeTypeSystemFlagsUseCase: + def is_attr_replicated( + self, + attribute_type_dto: AttributeTypeDTO, + ) -> bool: + """Check if attribute is replicated based on system_flags.""" + return not bool( + attribute_type_dto.system_flags + & AttributeTypeSystemFlags.ATTR_NOT_REPLICATED, + ) + + def set_attr_replication_flag( + self, + attribute_type_dto: AttributeTypeDTO, + need_to_replicate: bool, + ) -> AttributeTypeDTO: + """Set/clear replication flag in systemFlags.""" + if not need_to_replicate: + attribute_type_dto.system_flags = int( + attribute_type_dto.system_flags + | AttributeTypeSystemFlags.ATTR_NOT_REPLICATED, + ) + else: + attribute_type_dto.system_flags = int( + attribute_type_dto.system_flags + & ~AttributeTypeSystemFlags.ATTR_NOT_REPLICATED, + ) + + return attribute_type_dto diff --git a/app/ldap_protocol/ldap_schema/attribute_type_use_case.py b/app/ldap_protocol/ldap_schema/attribute_type_use_case.py index ebaf1f986..95f5425fc 100644 --- a/app/ldap_protocol/ldap_schema/attribute_type_use_case.py +++ b/app/ldap_protocol/ldap_schema/attribute_type_use_case.py @@ -9,6 +9,9 @@ from abstract_service import AbstractService from enums import AuthorizationRules from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO +from ldap_protocol.ldap_schema.attribute_type_system_flags_use_case import ( + AttributeTypeSystemFlagsUseCase, +) from ldap_protocol.ldap_schema.dto 
import AttributeTypeDTO from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.utils.pagination import PaginationParams, PaginationResult @@ -20,15 +23,19 @@ class AttributeTypeUseCase(AbstractService): def __init__( self, attribute_type_dao: AttributeTypeDAO, + attribute_type_system_flags_use_case: AttributeTypeSystemFlagsUseCase, object_class_dao: ObjectClassDAO, ) -> None: """Init AttributeTypeUseCase.""" self._attribute_type_dao = attribute_type_dao + self._attribute_type_system_flags_use_case = ( + attribute_type_system_flags_use_case + ) self._object_class_dao = object_class_dao - async def get(self, _id: str) -> AttributeTypeDTO: - """Get Attribute Type by id.""" - dto = await self._attribute_type_dao.get(_id) + async def get(self, name: str) -> AttributeTypeDTO: + """Get Attribute Type by name.""" + dto = await self._attribute_type_dao.get(name) dto.object_class_names = await self._object_class_dao.get_object_class_names_include_attribute_type( # noqa: E501 dto.name, ) @@ -42,13 +49,13 @@ async def create(self, dto: AttributeTypeDTO) -> None: """Create Attribute Type.""" await self._attribute_type_dao.create(dto) - async def update(self, _id: str, dto: AttributeTypeDTO) -> None: + async def update(self, name: str, dto: AttributeTypeDTO) -> None: """Update Attribute Type.""" - await self._attribute_type_dao.update(_id, dto) + await self._attribute_type_dao.update(name, dto) - async def delete(self, _id: str) -> None: + async def delete(self, name: str) -> None: """Delete Attribute Type.""" - await self._attribute_type_dao.delete(_id) + await self._attribute_type_dao.delete(name) async def get_paginator( self, @@ -68,10 +75,29 @@ async def delete_all_by_names(self, names: list[str]) -> None: """Delete not system Attribute Types by names.""" return await self._attribute_type_dao.delete_all_by_names(names) + async def is_attr_replicated(self, name: str) -> bool: + """Check if attribute is replicated based on systemFlags.""" + 
dto = await self.get(name) + return self._attribute_type_system_flags_use_case.is_attr_replicated(dto) # noqa: E501 # fmt: skip + + async def set_attr_replication_flag( + self, + name: str, + need_to_replicate: bool, + ) -> None: + """Set replication flag in systemFlags.""" + dto = await self.get(name) + dto = self._attribute_type_system_flags_use_case.set_attr_replication_flag( # noqa: E501 + dto, + need_to_replicate, + ) + await self._attribute_type_dao.update_sys_flags(dto.name, dto) + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { get.__name__: AuthorizationRules.ATTRIBUTE_TYPE_GET, create.__name__: AuthorizationRules.ATTRIBUTE_TYPE_CREATE, get_paginator.__name__: AuthorizationRules.ATTRIBUTE_TYPE_GET_PAGINATOR, # noqa: E501 update.__name__: AuthorizationRules.ATTRIBUTE_TYPE_UPDATE, delete_all_by_names.__name__: AuthorizationRules.ATTRIBUTE_TYPE_DELETE_ALL_BY_NAMES, # noqa: E501 + set_attr_replication_flag.__name__: AuthorizationRules.ATTRIBUTE_TYPE_SET_ATTR_REPLICATION_FLAG, # noqa: E501 } diff --git a/app/ldap_protocol/ldap_schema/dto.py b/app/ldap_protocol/ldap_schema/dto.py index 118a6e1e8..7699b6966 100644 --- a/app/ldap_protocol/ldap_schema/dto.py +++ b/app/ldap_protocol/ldap_schema/dto.py @@ -22,6 +22,7 @@ class AttributeTypeDTO(Generic[_IdT]): single_value: bool no_user_modification: bool is_system: bool + system_flags: int is_included_anr: bool id: _IdT = None # type: ignore object_class_names: set[str] = field(default_factory=set) diff --git a/app/ldap_protocol/ldap_schema/entity_type_dao.py b/app/ldap_protocol/ldap_schema/entity_type_dao.py index abfdc49d1..1a708d711 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_dao.py +++ b/app/ldap_protocol/ldap_schema/entity_type_dao.py @@ -85,9 +85,9 @@ async def create(self, dto: EntityTypeDTO[None]) -> None: f"Entity Type with name '{dto.name}' already exists.", ) - async def update(self, _id: str, dto: EntityTypeDTO[int]) -> None: + async def update(self, name: str, dto: EntityTypeDTO[int]) 
-> None: """Update an Entity Type.""" - entity_type = await self._get_one_raw_by_name(_id) + entity_type = await self._get_one_raw_by_name(name) try: await self.__object_class_dao.is_all_object_classes_exists( @@ -153,9 +153,9 @@ async def update(self, _id: str, dto: EntityTypeDTO[int]) -> None: f"names {dto.object_class_names} already exists.", ) - async def delete(self, _id: str) -> None: + async def delete(self, name: str) -> None: """Delete an Entity Type.""" - entity_type = await self._get_one_raw_by_name(_id) + entity_type = await self._get_one_raw_by_name(name) await self.__session.delete(entity_type) await self.__session.flush() @@ -182,10 +182,7 @@ async def get_paginator( session=self.__session, ) - async def _get_one_raw_by_name( - self, - name: str, - ) -> EntityType: + async def _get_one_raw_by_name(self, name: str) -> EntityType: """Get single Entity Type by name. :param str name: Entity Type name. @@ -203,14 +200,14 @@ async def _get_one_raw_by_name( ) return entity_type - async def get(self, _id: str) -> EntityTypeDTO: + async def get(self, name: str) -> EntityTypeDTO: """Get single Entity Type by name. :param str name: Entity Type name. :raise EntityTypeNotFoundError: If Entity Type not found. :return EntityType: Instance of Entity Type. 
""" - return _convert(await self._get_one_raw_by_name(_id)) + return _convert(await self._get_one_raw_by_name(name)) async def get_entity_type_by_object_class_names( self, diff --git a/app/ldap_protocol/ldap_schema/entity_type_use_case.py b/app/ldap_protocol/ldap_schema/entity_type_use_case.py index 5958e6a99..e7589c3f4 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_use_case.py +++ b/app/ldap_protocol/ldap_schema/entity_type_use_case.py @@ -42,10 +42,10 @@ async def create(self, dto: EntityTypeDTO) -> None: ) await self._entity_type_dao.create(dto) - async def update(self, _id: str, dto: EntityTypeDTO) -> None: + async def update(self, name: str, dto: EntityTypeDTO) -> None: """Update Entity Type.""" try: - entity_type = await self.get(_id) + entity_type = await self.get(name) except EntityTypeNotFoundError: raise EntityTypeCantModifyError @@ -53,13 +53,13 @@ async def update(self, _id: str, dto: EntityTypeDTO) -> None: raise EntityTypeCantModifyError( f"Entity Type '{dto.name}' is system and cannot be modified.", ) - if _id != dto.name: - await self._validate_name(name=_id) + if name != dto.name: + await self._validate_name(name=dto.name) await self._entity_type_dao.update(entity_type.name, dto) - async def get(self, _id: str) -> EntityTypeDTO: + async def get(self, name: str) -> EntityTypeDTO: """Get Entity Type by name.""" - return await self._entity_type_dao.get(_id) + return await self._entity_type_dao.get(name) async def _validate_name( self, diff --git a/app/ldap_protocol/ldap_schema/object_class_dao.py b/app/ldap_protocol/ldap_schema/object_class_dao.py index 9bc29644e..83bcd7eef 100644 --- a/app/ldap_protocol/ldap_schema/object_class_dao.py +++ b/app/ldap_protocol/ldap_schema/object_class_dao.py @@ -77,9 +77,9 @@ async def get_object_class_names_include_attribute_type( ) # fmt: skip return set(row[0] for row in result.fetchall()) - async def delete(self, _id: str) -> None: + async def delete(self, name: str) -> None: """Delete Object Class.""" - 
object_class = await self._get_one_raw_by_name(_id) + object_class = await self._get_one_raw_by_name(name) await self.__session.delete(object_class) await self.__session.flush() @@ -245,14 +245,14 @@ async def _get_one_raw_by_name(self, name: str) -> ObjectClass: ) return object_class - async def get(self, _id: str) -> ObjectClassDTO: - """Get single Object Class by id. + async def get(self, name: str) -> ObjectClassDTO: + """Get single Object Class by name. - :param str _id: Object Class name. + :param str name: Object Class name. :raise ObjectClassNotFoundError: If Object Class not found. :return ObjectClass: Instance of Object Class. """ - return _converter(await self._get_one_raw_by_name(_id)) + return _converter(await self._get_one_raw_by_name(name)) async def get_all_by_names( self, @@ -273,16 +273,9 @@ async def get_all_by_names( ) # fmt: skip return list(map(_converter, query.all())) - async def update(self, _id: str, dto: ObjectClassDTO[None, str]) -> None: - """Modify Object Class. - - :param ObjectClassDTO object_class: Object Class. - :param ObjectClassDTO dto: New statement ObjectClass - :raise ObjectClassCantModifyError: If Object Class is system,\ - it cannot be changed. - :return None. 
- """ - obj = await self._get_one_raw_by_name(_id) + async def update(self, name: str, dto: ObjectClassDTO[None, str]) -> None: + """Update Object Class.""" + obj = await self._get_one_raw_by_name(name) if obj.is_system: raise ObjectClassCantModifyError( "System Object Class cannot be modified.", diff --git a/app/ldap_protocol/ldap_schema/object_class_use_case.py b/app/ldap_protocol/ldap_schema/object_class_use_case.py index c35a845ac..11c171a58 100644 --- a/app/ldap_protocol/ldap_schema/object_class_use_case.py +++ b/app/ldap_protocol/ldap_schema/object_class_use_case.py @@ -30,9 +30,9 @@ async def get_all(self) -> list[ObjectClassDTO[int, AttributeTypeDTO]]: """Get all Object Classes.""" return await self._object_class_dao.get_all() - async def delete(self, _id: str) -> None: + async def delete(self, name: str) -> None: """Delete Object Class.""" - await self._object_class_dao.delete(_id) + await self._object_class_dao.delete(name) async def get_paginator( self, @@ -45,9 +45,9 @@ async def create(self, dto: ObjectClassDTO[None, str]) -> None: """Create a new Object Class.""" await self._object_class_dao.create(dto) - async def get(self, _id: str) -> ObjectClassDTO: - """Get Object Class by id.""" - dto = await self._object_class_dao.get(_id) + async def get(self, name: str) -> ObjectClassDTO: + """Get Object Class by name.""" + dto = await self._object_class_dao.get(name) dto.entity_type_names = ( await self._entity_type_dao.get_entity_type_names_include_oc_name( dto.name, @@ -62,9 +62,9 @@ async def get_all_by_names( """Get list of Object Classes by names.""" return await self._object_class_dao.get_all_by_names(names) - async def update(self, _id: str, dto: ObjectClassDTO[None, str]) -> None: + async def update(self, name: str, dto: ObjectClassDTO[None, str]) -> None: """Modify Object Class.""" - await self._object_class_dao.update(_id, dto) + await self._object_class_dao.update(name, dto) async def delete_all_by_names(self, names: list[str]) -> None: """Delete 
not system Object Classes by Names.""" diff --git a/app/ldap_protocol/master_check_use_case.py b/app/ldap_protocol/master_check_use_case.py new file mode 100644 index 000000000..2c010e788 --- /dev/null +++ b/app/ldap_protocol/master_check_use_case.py @@ -0,0 +1,30 @@ +"""Check Master Use Case. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import ClassVar, Protocol + +from abstract_service import AbstractService +from enums import AuthorizationRules + + +class MasterGatewayProtocol(Protocol): + """Master DB Gateway Protocol.""" + + async def check_master(self) -> bool: ... + + +class MasterCheckUseCase(AbstractService): + """Check Master Use Case.""" + + _master_gateway: MasterGatewayProtocol + + def __init__(self, master_gateway: MasterGatewayProtocol) -> None: + self._master_gateway = master_gateway + + async def check_master(self) -> bool: + return await self._master_gateway.check_master() + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = {} diff --git a/app/ldap_protocol/objects.py b/app/ldap_protocol/objects.py index 75effb3f0..d69301c0e 100644 --- a/app/ldap_protocol/objects.py +++ b/app/ldap_protocol/objects.py @@ -5,6 +5,7 @@ """ from enum import IntEnum, IntFlag, StrEnum, unique +from functools import cached_property from typing import Annotated import annotated_types @@ -82,8 +83,9 @@ class Changes(BaseModel): operation: Operation modification: PartialAttribute - def get_name(self) -> str: - """Get mod name.""" + @cached_property + def l_type(self) -> str: + """Get modification type (it's attribute name) in lower case.""" return self.modification.type.lower() diff --git a/app/ldap_protocol/policies/audit/events/dataclasses.py b/app/ldap_protocol/policies/audit/events/dataclasses.py index 78432b6ca..38c922043 100644 --- a/app/ldap_protocol/policies/audit/events/dataclasses.py +++ b/app/ldap_protocol/policies/audit/events/dataclasses.py @@ -208,6 +208,7 @@ def 
destination_dict(self) -> dict[str, Any]: "policy_id": self.policy_id, "details": self.details, "service_name": self.service_name, + "severity": self.severity, } diff --git a/app/ldap_protocol/policies/audit/monitor.py b/app/ldap_protocol/policies/audit/monitor.py index 5ce08d0ab..6258daf49 100644 --- a/app/ldap_protocol/policies/audit/monitor.py +++ b/app/ldap_protocol/policies/audit/monitor.py @@ -14,6 +14,7 @@ from config import Settings from entities import User +from ldap_protocol.auth.dto import LoginRequestDTO from ldap_protocol.auth.exceptions.mfa import ( AuthenticationError, ForbiddenError, @@ -22,7 +23,6 @@ MFATokenError, NetworkPolicyError, ) -from ldap_protocol.auth.schemas import OAuth2Form from ldap_protocol.identity.exceptions import ( AuthorizationError, AuthValidationError, @@ -224,7 +224,7 @@ async def wrapped_proxy_request( def wrap_login(self, attr: _T) -> _T: @wraps(attr) async def wrapped_login( - form: OAuth2Form, + form: LoginRequestDTO, url: URL, ip: IPv4Address | IPv6Address, user_agent: str, diff --git a/app/ldap_protocol/policies/network/use_cases.py b/app/ldap_protocol/policies/network/use_cases.py index cde4294d6..38dbc11cc 100644 --- a/app/ldap_protocol/policies/network/use_cases.py +++ b/app/ldap_protocol/policies/network/use_cases.py @@ -148,38 +148,58 @@ async def update( ) -> NetworkPolicyDTO: """Update network policy.""" policy = await self._network_policy_gateway.get_with_for_update(dto.id) + + await self._apply_field_updates(policy, dto) + await self._apply_netmask_updates(policy, dto) + await self._apply_group_updates(policy, dto) + + if await self._network_policy_gateway.check_policy_exists(policy): + raise NetworkPolicyAlreadyExistsError("Entry already exists") + + await self._session.commit() + + return _convert_model_to_dto(policy) + + async def _apply_field_updates( + self, + policy: NetworkPolicy, + dto: NetworkPolicyUpdateDTO, + ) -> None: + """Apply regular field updates.""" for field in dto.fields_to_update: value = 
getattr(dto, field) if value is not None: setattr(policy, field, value) + async def _apply_netmask_updates( + self, + policy: NetworkPolicy, + dto: NetworkPolicyUpdateDTO, + ) -> None: + """Apply netmask updates.""" if dto.netmasks and dto.raw: policy.netmasks = dto.netmasks policy.raw = dto.raw - if ( - dto.groups is not None - and len(dto.groups) > 0 - and len(dto.groups) != 0 - ): - policy.groups = await self._network_policy_gateway.get_groups( - dto.groups, + async def _apply_group_updates( + self, + policy: NetworkPolicy, + dto: NetworkPolicyUpdateDTO, + ) -> None: + """Apply group updates.""" + if dto.groups is not None: + policy.groups = ( + await self._network_policy_gateway.get_groups(dto.groups) + if dto.groups + else [] ) - if ( - dto.mfa_groups is not None - and len(dto.mfa_groups) > 0 - and len(dto.mfa_groups) != 0 - ): - policy.mfa_groups = await self._network_policy_gateway.get_groups( - dto.mfa_groups, - ) - if await self._network_policy_gateway.check_policy_exists(policy): - raise NetworkPolicyAlreadyExistsError( - "Entry already exists", + if dto.mfa_groups is not None: + policy.mfa_groups = ( + await self._network_policy_gateway.get_groups(dto.mfa_groups) + if dto.mfa_groups + else [] ) - await self._session.commit() - return _convert_model_to_dto(policy) async def swap_priorities(self, id1: int, id2: int) -> SwapPrioritiesDTO: """Swap priorities for network policies.""" diff --git a/app/ldap_protocol/roles/access_manager.py b/app/ldap_protocol/roles/access_manager.py index 9e7a964f4..0edc74190 100644 --- a/app/ldap_protocol/roles/access_manager.py +++ b/app/ldap_protocol/roles/access_manager.py @@ -123,17 +123,16 @@ def check_modify_access( return False for change in changes: - attr_name = change.get_name() if change.operation == Operation.DELETE: if not cls._check_modify_access( - attr_name, + change.l_type, filtered_aces, AceType.DELETE, ): return False elif change.operation == Operation.ADD: if not cls._check_modify_access( - attr_name, + 
change.l_type, filtered_aces, AceType.WRITE, ): @@ -141,12 +140,12 @@ def check_modify_access( else: if not ( cls._check_modify_access( - attr_name, + change.l_type, filtered_aces, AceType.WRITE, ) and cls._check_modify_access( - attr_name, + change.l_type, filtered_aces, AceType.DELETE, ) @@ -305,12 +304,19 @@ def mutate_query_with_ace_load( null attribute_type_id :return: mutated query with access control entries loaded """ - selectin_loader = selectinload( + base_loader = selectinload( qa(Directory.access_control_entries), ) + + loader_options = [ + base_loader.joinedload(qa(AccessControlEntry.entity_type)), + ] + if load_attribute_type: - selectin_loader = selectin_loader.joinedload( - qa(AccessControlEntry.attribute_type), + loader_options.append( + base_loader.joinedload( + qa(AccessControlEntry.attribute_type), + ), ) criteria_conditions = [ @@ -332,7 +338,7 @@ def mutate_query_with_ace_load( ) return query.options( - selectin_loader, + *loader_options, with_loader_criteria( AccessControlEntry, and_(*criteria_conditions), diff --git a/app/ldap_protocol/roles/ace_dao.py b/app/ldap_protocol/roles/ace_dao.py index 679268cfd..202060115 100644 --- a/app/ldap_protocol/roles/ace_dao.py +++ b/app/ldap_protocol/roles/ace_dao.py @@ -192,7 +192,6 @@ async def create_bulk(self, dtos: list[AccessControlEntryDTO]) -> None: objects to create. 
""" directory_cache = {} - new_aces = [] for ace in dtos: cache_key = (ace.base_dn, ace.scope) if cache_key not in directory_cache: @@ -219,9 +218,8 @@ async def create_bulk(self, dtos: list[AccessControlEntryDTO]) -> None: is_allow=ace.is_allow, directories=directory_cache[cache_key], ) - new_aces.append(new_ace) + self._session.add(new_ace) - self._session.add_all(new_aces) try: await self._session.flush() except IntegrityError: diff --git a/app/ldap_protocol/roles/role_use_case.py b/app/ldap_protocol/roles/role_use_case.py index d9c2921e0..75a1339f1 100644 --- a/app/ldap_protocol/roles/role_use_case.py +++ b/app/ldap_protocol/roles/role_use_case.py @@ -216,6 +216,31 @@ async def create_kerberos_system_role(self) -> None: ) await self._access_control_entry_dao.create_bulk(aces) + async def add_read_only_role_to_krbadmin_group(self) -> None: + """Add Read Only role to krbadmin group.""" + base_dn_list = await get_base_directories(self._role_dao._session) # noqa: SLF001 + if not base_dn_list: + return + + try: + read_only_role = await self._role_dao.get_by_name( + RoleConstants.READ_ONLY_ROLE_NAME, + ) + except RoleNotFoundError: + return + else: + new_groups_dn = [ + RoleConstants.KERBEROS_GROUP_CN + base_dn_list[0].path_dn, + RoleConstants.READONLY_GROUP_CN + base_dn_list[0].path_dn, + ] + + read_only_role.groups = new_groups_dn + + await self._role_dao.update( + read_only_role.get_id(), + read_only_role, + ) + async def delete_kerberos_system_role(self) -> None: """Delete the Kerberos system role.""" try: diff --git a/app/ldap_protocol/rootdse/netlogon.py b/app/ldap_protocol/rootdse/netlogon.py index 8efb2b78c..5dc225b0b 100644 --- a/app/ldap_protocol/rootdse/netlogon.py +++ b/app/ldap_protocol/rootdse/netlogon.py @@ -9,6 +9,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ +import codecs import ipaddress import struct import uuid @@ -180,13 +181,17 @@ def set_info(self) -> None: ) @staticmethod - def 
_convert_little_endian_string_to_int(value: str) -> int: + def _convert_little_endian_string_to_int(value: str | bytes) -> int: """Convert little-endian string to int.""" - return int.from_bytes( - value.encode().decode("unicode_escape").encode(), - byteorder="little", - signed=False, - ) + if isinstance(value, bytes): + return int.from_bytes(value, "little", signed=False) + + if "\\x" in value: + value = codecs.decode(value, "unicode_escape").encode("latin-1") + else: + value = value.encode("latin-1", errors="strict") + + return int.from_bytes(value, "little", signed=False) def get_attr(self) -> bytes: """Get NetLogon response.""" @@ -291,6 +296,8 @@ def _get_netlogon_response_5_ex(self) -> bytes: DSFlag.CLOSEST_FLAG, DSFlag.WRITABLE_FLAG, DSFlag.GOOD_TIMESERV_FLAG, + DSFlag.KDC_FLAG, + DSFlag.WS_FLAG, ]: ds_flags |= flag diff --git a/app/ldap_protocol/rootdse/reader.py b/app/ldap_protocol/rootdse/reader.py index 5f128fb9c..065be0a54 100644 --- a/app/ldap_protocol/rootdse/reader.py +++ b/app/ldap_protocol/rootdse/reader.py @@ -21,62 +21,69 @@ def __init__(self, settings: Settings, gw: DomainReadProtocol) -> None: async def get( self, - requested_attrs: list[str], + requested_attrs: set[str], ) -> defaultdict[str, list[str]]: domain = await self._gw.get_domain() - data = defaultdict(list) schema = "CN=Schema" - if requested_attrs == ["subschemasubentry"]: - data["subschemaSubentry"].append(schema) - return data - - data["dnsHostName"].append(domain.name) - data["serverName"].append(domain.name) - data["serviceName"].append(domain.name) - data["dsServiceName"].append(domain.name) - data["LDAPServiceName"].append(domain.name) - data["dnsForestName"].append(domain.name) - data["dnsDomainName"].append(domain.name) - data["domainGuid"].append(str(domain.object_guid)) - data["vendorName"].append(self._settings.VENDOR_NAME) - data["vendorVersion"].append(self._settings.VENDOR_VERSION) - data["namingContexts"].append(domain.path_dn) - data["namingContexts"].append(schema) 
- data["rootDomainNamingContext"].append(domain.path_dn) - data["supportedLDAPVersion"].append("3") - data["defaultNamingContext"].append(domain.path_dn) - data["currentTime"].append( - get_generalized_now(self._settings.TIMEZONE), - ) - data["subschemaSubentry"].append(schema) - data["schemaNamingContext"].append(schema) - data["supportedSASLMechanisms"] = [ - "ANONYMOUS", - "PLAIN", - "GSSAPI", - "GSS-SPNEGO", - ] - data["highestCommittedUSN"].append("126991") - data["supportedExtension"] = [ - "1.3.6.1.4.1.4203.1.11.3", # whoami - "1.3.6.1.4.1.4203.1.11.1", # password modify - ] - data["supportedControl"] = [ - "2.16.840.1.113730.3.4.4", # password expire policy - ] - data["domainFunctionality"].append("0") - data["supportedLDAPPolicies"] = [ - "MaxConnIdleTime", - "MaxPageSize", - "MaxValRange", - ] - data["supportedCapabilities"] = [ - "1.2.840.113556.1.4.800", # ACTIVE_DIRECTORY_OID - "1.2.840.113556.1.4.1670", # ACTIVE_DIRECTORY_V51_OID - "1.2.840.113556.1.4.1791", # ACTIVE_DIRECTORY_LDAP_INTEG_OID - ] - - return data + + all_attrs: dict[str, list[str]] = { + "dnsHostName": [domain.name], + "serverName": [domain.name], + "serviceName": [domain.name], + "dsServiceName": [domain.name], + "LDAPServiceName": [domain.name], + "dnsForestName": [domain.name], + "dnsDomainName": [domain.name], + "domainGuid": [str(domain.object_guid)], + "vendorName": [self._settings.VENDOR_NAME], + "vendorVersion": [self._settings.VENDOR_VERSION], + "namingContexts": [domain.path_dn, schema], + "rootDomainNamingContext": [domain.path_dn], + "supportedLDAPVersion": ["3"], + "defaultNamingContext": [domain.path_dn], + "currentTime": [ + get_generalized_now(self._settings.TIMEZONE), + ], + "subschemaSubentry": [schema], + "schemaNamingContext": [schema], + "supportedSASLMechanisms": [ + "ANONYMOUS", + "PLAIN", + "GSSAPI", + "GSS-SPNEGO", + ], + "highestCommittedUSN": ["126991"], + "supportedExtension": [ + "1.3.6.1.4.1.4203.1.11.3", # whoami + "1.3.6.1.4.1.4203.1.11.1", # password 
modify + ], + "supportedControl": [ + "2.16.840.1.113730.3.4.4", # password expire policy + ], + "domainFunctionality": ["7"], + "forestFunctionality": ["7"], + "supportedLDAPPolicies": [ + "MaxConnIdleTime", + "MaxPageSize", + "MaxValRange", + ], + "supportedCapabilities": [ + "1.2.840.113556.1.4.800", # ACTIVE_DIRECTORY_OID + "1.2.840.113556.1.4.1670", # ACTIVE_DIRECTORY_V51_OID + "1.2.840.113556.1.4.1791", # ACTIVE_DIRECTORY_LDAP_INTEG_OID + ], + } + + if not requested_attrs or "*" in requested_attrs: + return defaultdict(list, all_attrs) + + result = defaultdict(list) + + for attr_name, values in all_attrs.items(): + if attr_name.lower() in requested_attrs: + result[attr_name].extend(values) + + return result class DCInfoReader: diff --git a/app/ldap_protocol/session_storage/repository.py b/app/ldap_protocol/session_storage/repository.py index 84366faee..2e73dbc2d 100644 --- a/app/ldap_protocol/session_storage/repository.py +++ b/app/ldap_protocol/session_storage/repository.py @@ -1,9 +1,11 @@ """Enterprise Session Repository.""" +import contextlib from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address from typing import ClassVar, Literal +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from abstract_service import AbstractService @@ -87,8 +89,13 @@ async def create_session_key( }, ttl=ttl, ) + with contextlib.suppress(OperationalError): + await set_user_logon_attrs( + user, + self.session, + self.settings.TIMEZONE, + ) - await set_user_logon_attrs(user, self.session, self.settings.TIMEZONE) return key async def get_user_sessions( diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py new file mode 100644 index 000000000..f723440a6 --- /dev/null +++ b/app/ldap_protocol/utils/async_cache.py @@ -0,0 +1,46 @@ +"""Async cache implementation.""" + +import time +from functools import wraps +from typing import Awaitable, Callable, Generic, TypeVar + +from 
entities import Directory + +T = TypeVar("T") +DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes + + +class AsyncTTLCache(Generic[T]): + def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None: + self._ttl = ttl + self._value: T | None = None + self._expires_at: float | None = None + + def clear(self) -> None: + self._value = None + self._expires_at = None + + def __call__( + self, + func: Callable[..., Awaitable[T]], + ) -> Callable[..., Awaitable[T]]: + @wraps(func) + async def wrapper(*args: tuple, **kwargs: dict) -> T: + if self._value is not None: + if not self._expires_at or self._expires_at > time.monotonic(): + return self._value + self.clear() + + result = await func(*args, **kwargs) + + self._value = result + self._expires_at = ( + time.monotonic() + self._ttl if self._ttl else None + ) + + return result + + return wrapper + + +base_directories_cache = AsyncTTLCache[list[Directory]]() diff --git a/app/ldap_protocol/utils/cte.py b/app/ldap_protocol/utils/cte.py index e2cbe75ee..7b4628254 100644 --- a/app/ldap_protocol/utils/cte.py +++ b/app/ldap_protocol/utils/cte.py @@ -63,7 +63,7 @@ def find_members_recursive_cte( FROM "Directory" JOIN "Groups" ON "Directory".id = "Groups"."directoryId" WHERE "Directory"."path" = - '{dc=test,dc=md,cn=groups,"cn=domain admins"}' + '{dc=test,dc=md,cn=Groups,"cn=domain admins"}' UNION ALL @@ -129,7 +129,7 @@ def find_root_group_recursive_cte(dn_list: list) -> CTE: FROM "Directory" LEFT OUTER JOIN "Groups" ON "Directory".id = "Groups"."directoryId" WHERE "Directory"."path" = - '{dc=test,dc=md,cn=groups,"cn=domain admins"}' + '{dc=test,dc=md,cn=Groups,"cn=domain admins"}' UNION ALL diff --git a/app/ldap_protocol/utils/pagination.py b/app/ldap_protocol/utils/pagination.py index 5e4ef6e4b..34f9788d3 100644 --- a/app/ldap_protocol/utils/pagination.py +++ b/app/ldap_protocol/utils/pagination.py @@ -95,6 +95,12 @@ class PaginationResult[S, P]: metadata: PaginationMetadata items: Sequence[P] + @classmethod + def _validate_query(cls, 
query: Select[tuple[S]]) -> bool: + return not ( + query._order_by_clause is None or len(query._order_by_clause) == 0 # noqa: SLF001 + ) + @classmethod async def get( cls, @@ -104,7 +110,7 @@ async def get( session: AsyncSession, ) -> Self: """Get paginator.""" - if query._order_by_clause is None or len(query._order_by_clause) == 0: # noqa: SLF001 + if not cls._validate_query(query): raise ValueError("Select query must have an order_by clause.") metadata = PaginationMetadata( diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index a1f9243de..2e078b840 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -5,16 +5,23 @@ """ import time +from copy import copy from datetime import datetime from typing import Iterator from zoneinfo import ZoneInfo from sqlalchemy import Column, exists, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import InstrumentedAttribute, joinedload, selectinload +from sqlalchemy.orm import ( + InstrumentedAttribute, + contains_eager, + joinedload, + selectinload, +) from sqlalchemy.sql.expression import ColumnElement from entities import Attribute, Directory, Group, User +from enums import SamAccountTypeCodes from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, AttributeValueValidatorError, @@ -25,6 +32,7 @@ queryable_attr as qa, ) +from .async_cache import base_directories_cache from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( create_integer_hash, @@ -35,13 +43,19 @@ ) +@base_directories_cache async def get_base_directories(session: AsyncSession) -> list[Directory]: """Get base domain directories.""" result = await session.execute( select(Directory) .filter(qa(Directory.parent_id).is_(None)), ) # fmt: skip - return list(result.scalars().all()) + res = [] + for dir_ in result.scalars(): + new_dir = copy(dir_) + session.expunge(new_dir) + res.append(new_dir) + 
return res async def get_user(session: AsyncSession, name: str) -> User | None: @@ -320,7 +334,7 @@ async def get_dn_by_id(id_: int, session: AsyncSession) -> str: """Get dn by id. >>> await get_dn_by_id(0, session) - >>> "cn=groups,dc=example,dc=com" + >>> "cn=Groups,dc=example,dc=com" """ query = select(Directory).filter_by(id=id_) retval = (await session.scalars(query)).one() @@ -345,7 +359,7 @@ async def create_group( ) -> tuple[Directory, Group]: """Create group in default groups path. - cn=name,cn=groups,dc=domain,dc=com + cn=name,cn=Groups,dc=domain,dc=com :param str name: group name :param int sid: objectSid @@ -354,7 +368,7 @@ async def create_group( base_dn_list = await get_base_directories(session) query = select(Directory).filter( - get_filter_from_path("cn=groups," + base_dn_list[0].path_dn), + get_filter_from_path("cn=Groups," + base_dn_list[0].path_dn), ) parent = (await session.scalars(query)).one() @@ -362,7 +376,7 @@ async def create_group( dir_ = Directory( object_class="", name=name, - parent=parent, + parent_id=parent.id, ) session.add(dir_) await session.flush() @@ -386,7 +400,7 @@ async def create_group( "instanceType": ["4"], "sAMAccountName": [dir_.name], dir_.rdname: [dir_.name], - "sAMAccountType": ["268435456"], + "sAMAccountType": [str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value)], "gidNumber": [str(create_integer_hash(dir_.name))], } @@ -530,3 +544,30 @@ async def set_or_update_primary_group( ) await session.commit() + + +async def get_group_path_dn_by_primary_group_id( + primary_group_id: int, + session: AsyncSession, +) -> str: + """Get group path DN by primary group ID. 
+ + :param int primary_group_id: primary group ID + :param AsyncSession session: db session + :return str: group path DN + :raises ValueError: if no group found with the given primaryGroupID + """ + query = ( + select(Directory) + .join(qa(Directory.group)) + .options(contains_eager(qa(Directory.group))) + .filter(qa(Directory.object_sid).endswith(f"-{primary_group_id}")) + ) + + directory = await session.scalar(query) + if directory is None: + raise ValueError( + f"No group found with primaryGroupID '{primary_group_id}'.", + ) + + return directory.path_dn diff --git a/app/ldap_protocol/utils/raw_definition_parser.py b/app/ldap_protocol/utils/raw_definition_parser.py index 0d3ddfa27..4fa7361e0 100644 --- a/app/ldap_protocol/utils/raw_definition_parser.py +++ b/app/ldap_protocol/utils/raw_definition_parser.py @@ -59,6 +59,7 @@ def create_attribute_type_by_raw( single_value=attribute_type_info.single_value, no_user_modification=attribute_type_info.no_user_modification, is_system=True, + system_flags=0, is_included_anr=False, ) diff --git a/app/multidirectory.py b/app/multidirectory.py index 22a19259d..7e7c58862 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -50,6 +50,9 @@ MFAProvider, ) from ldap_protocol.dependency import resolve_deps +from ldap_protocol.dns.bind_to_pdns_migration_use_case import ( + BindToPDNSMigrationUseCase, +) from ldap_protocol.identity.exceptions import UnauthorizedError from ldap_protocol.policies.audit.events.handler import AuditEventHandler from ldap_protocol.policies.audit.events.sender import AuditEventSenderManager @@ -287,6 +290,18 @@ async def event_sender_factory(settings: Settings) -> None: await asyncio.gather(manager.run()) +async def migrate_dns_factory(settings: Settings) -> None: + """Run DNS migration.""" + main_container = make_async_container( + MainProvider(), + context={Settings: settings}, + ) + + async with main_container(scope=Scope.REQUEST) as container: + usecase = await 
container.get(BindToPDNSMigrationUseCase) + await usecase.migrate() + + ldap = partial(run_entrypoint, factory=ldap_factory) cldap = partial(run_entrypoint, factory=cldap_factory) global_ldap_server = partial( @@ -297,6 +312,7 @@ async def event_sender_factory(settings: Settings) -> None: create_shadow_app = partial(create_prod_app, factory=_create_shadow_app) event_handler = partial(run_entrypoint, factory=event_handler_factory) event_sender = partial(run_entrypoint, factory=event_sender_factory) +dns_migration = partial(run_entrypoint, factory=migrate_dns_factory) if __name__ == "__main__": @@ -334,6 +350,11 @@ async def event_sender_factory(settings: Settings) -> None: action="store_true", help="Make migrations", ) + group.add_argument( + "--migrate_dns", + action="store_true", + help="Migrate DNS from BIND to PowerDNS", + ) args = parser.parse_args() @@ -376,3 +397,5 @@ async def event_sender_factory(settings: Settings) -> None: dump_acme_cert() elif args.migrate: command.upgrade(Config("alembic.ini"), "head") + elif args.migrate_dns: + dns_migration(settings=settings) diff --git a/app/repo/pg/master_gateway.py b/app/repo/pg/master_gateway.py new file mode 100644 index 000000000..20476c8d3 --- /dev/null +++ b/app/repo/pg/master_gateway.py @@ -0,0 +1,33 @@ +"""Master DB Gateway. 
+ +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from loguru import logger +from sqlalchemy import text +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings +from enums import PostgresRWModeType + + +class PGMasterGateway: + def __init__(self, session: AsyncSession, settings: Settings) -> None: + self._session = session + self._settings = settings + + async def check_master(self) -> bool: + if self._settings.POSTGRES_RW_MODE == PostgresRWModeType.SINGLE: + return True + + try: + self._session.sync_session.set_force_master(True) # type: ignore + await self._session.execute(text("SELECT 1")) + except OperationalError as e: + logger.error(f"Master DB check failed: {e}") + return False + else: + self._session.sync_session.set_force_master(False) # type: ignore + return True diff --git a/app/repo/pg/tables.py b/app/repo/pg/tables.py index 5391c95d5..a13db43ae 100644 --- a/app/repo/pg/tables.py +++ b/app/repo/pg/tables.py @@ -343,6 +343,7 @@ def _compile_create_uc( Column("single_value", Boolean, nullable=False), Column("no_user_modification", Boolean, nullable=False), Column("is_system", Boolean, nullable=False), + Column("system_flags", Integer, nullable=False, server_default=text("0")), Column("is_included_anr", Boolean, nullable=False), Index("idx_attribute_types_name_gin_trgm", "name", postgresql_using="gin"), ) diff --git a/app/schedule.py b/app/schedule.py index 35e59a85d..22fc26cd4 100644 --- a/app/schedule.py +++ b/app/schedule.py @@ -7,6 +7,7 @@ from loguru import logger from config import Settings +from extra.scripts.add_domain_controller import add_domain_controller from extra.scripts.check_ldap_principal import check_ldap_principal from extra.scripts.principal_block_user_sync import principal_block_sync from extra.scripts.uac_sync import disable_accounts @@ -27,6 +28,7 @@ (update_krb5_config, -1.0), 
(update_admin_permissions, -1.0), (update_status_process_events, 300.0), + (add_domain_controller, 600.0), } diff --git a/dnsdist.conf b/dnsdist.conf new file mode 100644 index 000000000..7446a4db5 --- /dev/null +++ b/dnsdist.conf @@ -0,0 +1,6 @@ +setLocal('0.0.0.0:53') +controlSocket('0.0.0.0:8084') +setKey('PSAag0AEziPZuBB7kdcfIEkVJOyQInRcBRAhadWDpU0=') +addConsoleACL('172.20.0.0/24') +includeDirectory('/etc/dnsdist/conf.d/') +setACL('0.0.0.0/0') diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a28c8eae4..31c8dd8bc 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -2,6 +2,8 @@ services: traefik: image: "traefik:v3.6.1" container_name: traefik + networks: + md_net: restart: unless-stopped command: # - --metrics @@ -15,10 +17,10 @@ services: - "443:443" - "389:389" - "389:389/udp" + - "3268:3268" + - "3269:3269" - "636:636" - "749:749" - - "53:53" - - "53:53/udp" volumes: - "/var/run/docker.sock:/var/run/docker.sock:ro" - ./certs:/letsencrypt @@ -42,6 +44,8 @@ services: traefik_certs_dumper: image: multidirectory container_name: traefik_certs_dumper + networks: + md_net: restart: "on-failure" volumes: - ./certs:/certs @@ -54,6 +58,8 @@ services: interface: container_name: multidirectory_interface + networks: + md_net: # image: ghcr.io/multidirectorylab/multidirectory-web-admin::beta restart: unless-stopped build: @@ -79,6 +85,8 @@ services: cert_check: image: multidirectory container_name: multidirectory_certs_check + networks: + md_net: restart: "no" volumes: - ./certs:/certs @@ -93,6 +101,8 @@ services: DOCKER_BUILDKIT: 1 target: runtime image: multidirectory + networks: + md_net: restart: unless-stopped environment: - SERVICE_NAME=cldap_server @@ -129,6 +139,8 @@ services: DOCKER_BUILDKIT: 1 target: runtime image: multidirectory + networks: + md_net: restart: unless-stopped deploy: mode: replicated @@ -143,6 +155,7 @@ services: volumes: - ./app:/app - ./certs:/certs + - ./logs:/app/logs - ldap_keytab:/LDAP_keytab/ env_file: 
local.env command: python -OO multidirectory.py --global_ldap_server @@ -165,7 +178,7 @@ services: - traefik.tcp.routers.global_ldap.entrypoints=global_ldap - traefik.tcp.routers.global_ldap.service=global_ldap - traefik.tcp.services.global_ldap.loadbalancer.server.port=3268 - - traefik.tcp.services.global_ldap.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.global_ldap.loadbalancer.serversTransport=ldap_transport@file - traefik.tcp.routers.global_ldap_tls.rule=HostSNI(`*`) - traefik.tcp.routers.global_ldap_tls.entrypoints=global_ldap_tls @@ -173,11 +186,13 @@ services: - traefik.tcp.routers.global_ldap_tls.tls=true - traefik.tcp.routers.global_ldap_tls.tls.certresolver=md-resolver - traefik.tcp.services.global_ldap_tls.loadbalancer.server.port=3269 - - traefik.tcp.services.global_ldap_tls.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.global_ldap_tls.loadbalancer.serversTransport=ldap_transport@file cert_local_check: image: multidirectory container_name: multidirectory_local_certs_check + networks: + md_net: restart: "no" volumes: - ./certs:/certs @@ -186,6 +201,8 @@ services: migrations: image: multidirectory container_name: multidirectory_migrations + networks: + md_net: restart: "no" command: python multidirectory.py --migrate env_file: @@ -205,6 +222,8 @@ services: DOCKER_BUILDKIT: 1 target: runtime image: multidirectory + networks: + md_net: restart: unless-stopped hostname: multidirectory volumes: @@ -236,7 +255,7 @@ services: - traefik.tcp.routers.ldap.entrypoints=ldap - traefik.tcp.routers.ldap.service=ldap - traefik.tcp.services.ldap.loadbalancer.server.port=389 - - traefik.tcp.services.ldap.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.ldap.loadbalancer.serversTransport=ldap_transport@file - traefik.tcp.routers.ldaps.rule=HostSNI(`*`) - traefik.tcp.routers.ldaps.entrypoints=ldaps @@ -244,7 +263,7 @@ services: - traefik.tcp.routers.ldaps.tls=true - traefik.tcp.routers.ldaps.tls.certresolver=md-resolver - 
traefik.tcp.services.ldaps.loadbalancer.server.port=636 - - traefik.tcp.services.ldaps.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.ldaps.loadbalancer.serversTransport=ldap_transport@file healthcheck: test: ["CMD-SHELL", "nc -zv 127.0.0.1 389 636"] interval: 30s @@ -260,6 +279,8 @@ services: USE_CORE_TLS: 1 SERVICE_NAME: multidirectory_api KRB5_LDAP_URI: ldap://ldap_server + networks: + md_net: hostname: api_server env_file: local.env @@ -288,6 +309,8 @@ services: shadow_api: image: multidirectory container_name: shadow_api + networks: + md_net: restart: unless-stopped tty: true depends_on: @@ -308,6 +331,8 @@ services: maintence: image: multidirectory container_name: md_maintence + networks: + md_net: restart: unless-stopped volumes: - ./certs:/certs @@ -332,6 +357,8 @@ services: postgres: container_name: MD-postgres + networks: + md_net: image: postgres:16 restart: unless-stopped environment: @@ -356,6 +383,8 @@ services: args: VERSION: beta container_name: kdc + networks: + md_net: restart: unless-stopped hostname: kerberos volumes: @@ -375,6 +404,8 @@ services: kadmin_api: image: krb5md container_name: kadmin-api + networks: + md_net: restart: unless-stopped volumes: - ./certs:/certs @@ -395,6 +426,8 @@ services: kadmind: container_name: kadmind + networks: + md_net: restart: unless-stopped hostname: kerberos volumes: @@ -421,6 +454,8 @@ services: dragonfly: image: 'docker.dragonflydb.io/dragonflydb/dragonfly' container_name: dragonfly + networks: + md_net: restart: always expose: - 6379 @@ -433,30 +468,91 @@ services: cpus: '0.25' memory: 0.5GiB - bind_dns: + kea_dhcp4: + image: kea_image:0.1 + network_mode: host + cap_add: + - NET_ADMIN build: context: . 
- dockerfile: ./.docker/bind9.Dockerfile - image: bind9md - container_name: bind9 - hostname: bind9 + dockerfile: ./.docker/kea.Dockerfile + container_name: kea_dhcp4 + tty: true restart: unless-stopped - environment: - - DEFAULT_NAMESERVER=192.168.69.241 - - USE_CONFIG_FILE_LOGGING=true + command: -c /kea/config/kea-dhcp4.conf volumes: - - dns_server_file:/opt/ - - dns_server_config:/etc/bind/ - - .dns/:/server/ + - dhcp:/kea/config + - sockets:/kea/sockets + - leases:/kea/leases + + kea_ctrl_agent: + image: jonasal/kea-ctrl-agent:3.1.2-alpine + container_name: kea_ctrl_agent + networks: + md_net: + restart: unless-stopped + command: -c /kea/config/kea-ctrl-agent.conf tty: true depends_on: - ldap_server: - condition: service_healthy - restart: true - labels: - - traefik.enable=true - - traefik.udp.routers.bind_dns_udp.entrypoints=bind_dns_udp - - traefik.udp.services.bind_dns_udp.loadbalancer.server.port=53 + kea_dhcp4: + condition: service_started + volumes: + - ./.package/kea-ctrl-agent.conf:/kea/config/kea-ctrl-agent.conf + - sockets:/kea/sockets + - leases:/kea/leases + + pdns_auth: + build: + context: . 
+ dockerfile: ./.docker/pdns_auth.Dockerfile + args: + DOCKER_BUILDKIT: 1 + image: pdns_auth_md + container_name: pdns_auth + networks: + default: + md_net: + ipv4_address: 172.20.0.202 + expose: + - 8082 + - 53/udp + - 53/tcp + volumes: + - dns_lmdb:/var/lib/pdns-lmdb + - dns_config:/etc/powerdns + + + pdnsdist: + image: powerdns/dnsdist-19:1.9.11 + container_name: pdnsdist + networks: + default: + md_net: + ipv4_address: 172.20.0.201 + expose: + - 8084 + ports: + - "53:53/tcp" + - "53:53/udp" + volumes: + - ./dnsdist.conf:/etc/dnsdist/dnsdist.conf + - dnsdist_confd:/etc/dnsdist/conf.d + + + pdns_recursor: + image: powerdns/pdns-recursor-51:5.1.7 + container_name: pdns_recursor + networks: + default: + md_net: + ipv4_address: 172.20.0.200 + expose: + - 8083 + - 53/udp + - 53/tcp + volumes: + - ./.package/recursor.conf:/etc/powerdns/recursor.conf + - forward_zones:/etc/powerdns/recursor.d/ event_handler: image: multidirectory @@ -491,7 +587,13 @@ services: environment: HANDLER_NAME: event_sender-1 - +networks: + md_net: + driver: bridge + ipam: + config: + - subnet: 172.20.0.0/24 + gateway: 172.20.0.1 volumes: postgres: @@ -500,5 +602,13 @@ volumes: kdc: dns_server_file: dns_server_config: + kea_ctrl_agent: ldap_keytab: dragonflydata: + dnsdist_confd: + dns_lmdb: + dns_config: + dhcp: + sockets: + leases: + forward_zones: diff --git a/docker-compose.remote.test.yml b/docker-compose.remote.test.yml index bd5659b02..4c4556cee 100644 --- a/docker-compose.remote.test.yml +++ b/docker-compose.remote.test.yml @@ -5,6 +5,10 @@ services: environment: DEBUG: 1 DOMAIN: md.test + PDNS_API_KEY: testkey123 + PDNS_DIST_KEY: testkey123 + DEFAULT_NAMESERVER: 127.0.0.1 + HOST_MACHINE_NAME: DC1 POSTGRES_USER: user1 POSTGRES_PASSWORD: password123 SECRET_KEY: 6a0452ae20cab4e21b6e9d18fa4b7bf397dd66ec3968b2d7407694278fd84cce diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 96076b657..2c0ac63d2 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ 
-14,8 +14,12 @@ services: environment: DEBUG: 1 DOMAIN: md.test + HOST_MACHINE_NAME: DC1 + DEFAULT_NAMESERVER: 127.0.0.1 POSTGRES_USER: user1 POSTGRES_PASSWORD: password123 + PDNS_API_KEY: testkey123 + PDNS_DIST_KEY: testkey123 SECRET_KEY: 6a0452ae20cab4e21b6e9d18fa4b7bf397dd66ec3968b2d7407694278fd84cce POSTGRES_HOST: postgres # PYTHONTRACEMALLOC: 1 diff --git a/docker-compose.yml b/docker-compose.yml index 543042d34..47dc60214 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,6 +5,8 @@ services: traefik: image: "traefik:v3.6.1" container_name: "traefik" + networks: + md_net: command: - "--providers.file.filename=/traefik.yml" ports: @@ -17,8 +19,6 @@ services: - "636:636" - "749:749" - "464:464" - - "53:53" - - "53:53/udp" volumes: - "/var/run/docker.sock:/var/run/docker.sock:ro" - "./certs:/certs" @@ -32,6 +32,8 @@ services: DOCKER_BUILDKIT: 1 target: runtime image: multidirectory + networks: + md_net: restart: unless-stopped environment: - SERVICE_NAME=ldap_server @@ -63,14 +65,14 @@ services: - traefik.tcp.routers.ldap.entrypoints=ldap - traefik.tcp.routers.ldap.service=ldap - traefik.tcp.services.ldap.loadbalancer.server.port=389 - - traefik.tcp.services.ldap.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.ldap.loadbalancer.serversTransport=ldap_transport@file - traefik.tcp.routers.ldaps.rule=HostSNI(`*`) - traefik.tcp.routers.ldaps.entrypoints=ldaps - traefik.tcp.routers.ldaps.service=ldaps - traefik.tcp.routers.ldaps.tls=true - traefik.tcp.services.ldaps.loadbalancer.server.port=636 - - traefik.tcp.services.ldaps.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.ldaps.loadbalancer.serversTransport=ldap_transport@file healthcheck: test: ["CMD-SHELL", "nc -zv 127.0.0.1 389 636"] interval: 30s @@ -86,6 +88,8 @@ services: DOCKER_BUILDKIT: 1 target: runtime image: multidirectory + networks: + md_net: user: root restart: unless-stopped environment: @@ -123,6 +127,8 @@ services: DOCKER_BUILDKIT: 1 target: runtime image: 
multidirectory + networks: + md_net: restart: unless-stopped deploy: mode: replicated @@ -137,6 +143,7 @@ services: volumes: - ./app:/app - ./certs:/certs + - ./logs:/app/logs - ldap_keytab:/LDAP_keytab/ env_file: local.env command: python -OO multidirectory.py --global_ldap_server @@ -159,24 +166,26 @@ services: - traefik.tcp.routers.global_ldap.entrypoints=global_ldap - traefik.tcp.routers.global_ldap.service=global_ldap - traefik.tcp.services.global_ldap.loadbalancer.server.port=3268 - - traefik.tcp.services.global_ldap.loadbalancer.proxyprotocol.version=2 + - traefik.tcp.services.global_ldap.loadbalancer.serversTransport=ldap_transport@file - traefik.tcp.routers.global_ldap_tls.rule=HostSNI(`*`) - traefik.tcp.routers.global_ldap_tls.entrypoints=global_ldap_tls - traefik.tcp.routers.global_ldap_tls.service=global_ldap_tls - traefik.tcp.routers.global_ldap_tls.tls=true - traefik.tcp.services.global_ldap_tls.loadbalancer.server.port=3269 - - traefik.tcp.services.global_ldap_tls.loadbalancer.proxyprotocol.version=2 - + - traefik.tcp.services.global_ldap_tls.loadbalancer.serversTransport=ldap_transport@file api: image: multidirectory container_name: multidirectory_api + networks: + md_net: volumes: - ./app:/app - ./certs:/certs - dns_server_file:/DNS_server_file/ - dns_server_config:/DNS_server_configs/ - ldap_keytab:/LDAP_keytab/ + - dnsdist_confd:/dnsdist env_file: local.env command: python multidirectory.py --http environment: @@ -190,7 +199,6 @@ services: - "traefik.http.routers.api.service=api" - "traefik.http.routers.api.middlewares=api_strip" - "traefik.http.middlewares.api_strip.stripprefix.prefixes=/api" - - "traefik.http.middlewares.api_strip.stripprefix.forceslash=false" depends_on: migrations: condition: service_completed_successfully @@ -205,6 +213,8 @@ services: migrations: image: multidirectory container_name: multidirectory_migrations + networks: + md_net: restart: "no" volumes: - ./app:/app @@ -218,6 +228,8 @@ services: cert_check: image: 
multidirectory container_name: multidirectory_certs_check + networks: + md_net: restart: "no" volumes: - ./certs:/certs @@ -227,6 +239,8 @@ services: cert_local_check: image: multidirectory container_name: multidirectory_local_certs_check + networks: + md_net: restart: "no" volumes: - ./certs:/certs @@ -234,6 +248,8 @@ services: postgres: container_name: MD-postgres + networks: + md_net: image: postgres:16 restart: unless-stopped ports: @@ -253,6 +269,8 @@ services: pgadmin: container_name: pgadmin_container + networks: + md_net: image: dpage/pgadmin4 environment: PGADMIN_DEFAULT_EMAIL: ${PGADMIN_DEFAULT_EMAIL:-pgadmin4@pgadmin.org} @@ -272,6 +290,8 @@ services: kadmin_api: image: krb5md container_name: kadmin_api + networks: + md_net: restart: unless-stopped volumes: - ./certs:/certs @@ -291,31 +311,6 @@ services: working_dir: /server command: ./entrypoint.sh - bind_dns: - build: - context: . - dockerfile: ./.docker/bind9.Dockerfile - image: bind9md - container_name: bind9 - hostname: bind9 - restart: unless-stopped - environment: - - DEFAULT_NAMESERVER=127.0.0.2 - - USE_CONFIG_FILE_LOGGING=true - volumes: - - dns_server_file:/opt/ - - dns_server_config:/etc/bind/ - - .dns/:/server/ - tty: true - depends_on: - ldap_server: - condition: service_healthy - restart: true - labels: - - traefik.enable=true - - traefik.udp.routers.bind_dns_udp.entrypoints=bind_dns_udp - - traefik.udp.services.bind_dns_udp.loadbalancer.server.port=53 - kea_dhcp4: image: kea_image:0.1 network_mode: host @@ -336,6 +331,8 @@ services: kea_ctrl_agent: image: jonasal/kea-ctrl-agent:3.1.2-alpine container_name: kea_ctrl_agent + networks: + md_net: restart: unless-stopped command: -c /kea/config/kea-ctrl-agent.conf tty: true @@ -354,6 +351,8 @@ services: args: VERSION: beta container_name: kdc + networks: + md_net: hostname: kerberos restart: unless-stopped volumes: @@ -372,6 +371,8 @@ services: kadmind: container_name: kadmind + networks: + md_net: restart: unless-stopped hostname: kerberos 
volumes: @@ -400,6 +401,8 @@ services: shadow_api: image: multidirectory container_name: shadow_api + networks: + md_net: restart: unless-stopped tty: true depends_on: @@ -421,6 +424,8 @@ services: maintence: image: multidirectory container_name: md_maintence + networks: + md_net: restart: unless-stopped volumes: - ./certs:/certs @@ -444,6 +449,8 @@ services: interface: container_name: multidirectory_interface + networks: + md_net: build: context: ./interface dockerfile: configurations/docker/Dockerfile.dev @@ -466,6 +473,8 @@ services: dragonfly: image: "docker.dragonflydb.io/dragonflydb/dragonfly" container_name: dragonfly + networks: + md_net: expose: - 6379 deploy: @@ -479,6 +488,8 @@ services: redis-commander: container_name: redis-commander + networks: + md_net: hostname: redis-commander image: ghcr.io/joeferner/redis-commander:latest restart: always @@ -495,6 +506,8 @@ services: event_handler: image: multidirectory container_name: event_handler + networks: + md_net: restart: unless-stopped tty: true env_file: local.env @@ -510,6 +523,8 @@ services: event_sender: image: multidirectory container_name: event_sender + networks: + md_net: restart: unless-stopped tty: true depends_on: @@ -526,12 +541,72 @@ services: syslog: image: balabit/syslog-ng:latest container_name: syslog-server + networks: + md_net: volumes: - ./syslog:/var/log - ./syslog-ng.conf:/etc/syslog-ng/syslog-ng.conf privileged: true restart: always + pdns_auth: + build: + context: . 
+ dockerfile: ./.docker/pdns_auth.Dockerfile + args: + DOCKER_BUILDKIT: 1 + image: pdns_auth_md + container_name: pdns_auth + networks: + md_net: + ipv4_address: 172.20.0.202 + expose: + - 8082 + - 53/udp + - 53/tcp + volumes: + - dns_lmdb:/var/lib/pdns-lmdb + - dns_config:/etc/powerdns + + + pdnsdist: + image: powerdns/dnsdist-19:1.9.11 + container_name: pdnsdist + networks: + md_net: + ipv4_address: 172.20.0.201 + expose: + - 8084 + ports: + - "53:53/tcp" + - "53:53/udp" + volumes: + - ./dnsdist.conf:/etc/dnsdist/dnsdist.conf + - dnsdist_confd:/etc/dnsdist/conf.d + + + pdns_recursor: + image: powerdns/pdns-recursor-51:5.1.7 + container_name: pdns_recursor + networks: + md_net: + ipv4_address: 172.20.0.200 + expose: + - 8083 + - 53/udp + - 53/tcp + volumes: + - ./.package/recursor.conf:/etc/powerdns/recursor.conf + - forward_zones:/etc/powerdns/recursor.d/ + +networks: + md_net: + driver: bridge + ipam: + config: + - subnet: 172.20.0.0/24 + gateway: 172.20.0.1 + volumes: postgres: pgadmin: @@ -544,3 +619,7 @@ volumes: leases: sockets: dhcp: + dns_lmdb: + dns_config: + forward_zones: + dnsdist_confd: diff --git a/interface b/interface deleted file mode 160000 index f31962020..000000000 --- a/interface +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f31962020a6689e6a4c61fb3349db5b5c7895f92 diff --git a/local.env b/local.env index 8eb377378..eb0aa533e 100644 --- a/local.env +++ b/local.env @@ -1,8 +1,12 @@ DEBUG=1 AUTO_RELOAD=1 DOMAIN=md.localhost +HOST_MACHINE_NAME=DC1 POSTGRES_USER=user1 POSTGRES_PASSWORD=password123 SECRET_KEY=6a0452ae20cab4e21b6e9d18fa4b7bf397dd66ec3968b2d7407694278fd84cce MFA_API_SOURCE=dev ACCESS_TOKEN_EXPIRE_MINUTES=180 +DEFAULT_NAMESERVER=172.20.0.4 +PDNS_API_KEY=supersecretapikey +PDNS_DIST_KEY=PSAag0AEziPZuBB7kdcfIEkVJOyQInRcBRAhadWDpU0= diff --git a/logs/.gitignore b/logs/.gitignore new file mode 100644 index 000000000..e69de29bb diff --git a/pyproject.toml b/pyproject.toml index f7adf0e26..a2ccd8416 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "bcrypt==4.0.1", "cryptography>=44.0.1", "dishka>=1.6.0", + "dnsdist-console>=1.6.0", "dnspython>=2.7.0", "fastapi>=0.115.0", "fastapi-error-map>=0.9.8", diff --git a/tests/conftest.py b/tests/conftest.py index c9ba0f8ff..9be038db5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,6 @@ import asyncio import os import uuid -import weakref from contextlib import suppress from dataclasses import dataclass from typing import AsyncGenerator, AsyncIterator, Generator, Iterator @@ -50,10 +49,10 @@ ) from api.auth.utils import get_ip_from_request, get_user_agent_from_request from api.dhcp.adapter import DHCPAdapter +from api.dns.adapter import DNSFastAPIAdapter from api.ldap_schema.adapters.attribute_type import AttributeTypeFastAPIAdapter from api.ldap_schema.adapters.entity_type import LDAPEntityTypeFastAPIAdapter from api.ldap_schema.adapters.object_class import ObjectClassFastAPIAdapter -from api.main.adapters.dns import DNSFastAPIAdapter from api.main.adapters.kerberos import KerberosFastAPIAdapter from api.network.adapters.network import NetworkPolicyFastAPIAdapter from api.password_policy.adapter import ( @@ -74,11 +73,10 @@ from ldap_protocol.dialogue import LDAPSession from ldap_protocol.dns import ( AbstractDNSManager, - DNSManagerSettings, + DNSSettingsDTO, StubDNSManager, ) from ldap_protocol.dns.dns_gateway import DNSStateGateway -from ldap_protocol.dns.dto import DNSSettingDTO from ldap_protocol.dns.use_cases import DNSUseCase from ldap_protocol.identity import IdentityProvider from ldap_protocol.identity.provider_gateway import IdentityProviderGateway @@ -99,6 +97,9 @@ LDAPUnbindRequestContext, ) from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO +from ldap_protocol.ldap_schema.attribute_type_system_flags_use_case import ( + AttributeTypeSystemFlagsUseCase, +) from ldap_protocol.ldap_schema.attribute_type_use_case import ( AttributeTypeUseCase, ) @@ -110,6 +111,10 @@ 
from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.master_check_use_case import ( + MasterCheckUseCase, + MasterGatewayProtocol, +) from ldap_protocol.multifactor import LDAPMultiFactorAPI, MultifactorAPI from ldap_protocol.permissions_checker import AuthorizationProvider from ldap_protocol.policies.audit.audit_use_case import AuditUseCase @@ -157,6 +162,7 @@ from ldap_protocol.session_storage.repository import SessionRepository from ldap_protocol.utils.queries import get_user from password_utils import PasswordUtils +from repo.pg.master_gateway import PGMasterGateway from tests.constants import TEST_DATA @@ -188,7 +194,7 @@ async def get_kadmin(self) -> AsyncIterator[AsyncMock]: kadmin.get_status = AsyncMock(return_value=False) kadmin.add_principal = AsyncMock() kadmin.del_principal = AsyncMock() - kadmin.rename_princ = AsyncMock() + kadmin.modify_princ = AsyncMock() kadmin.create_or_update_principal_pw = AsyncMock() kadmin.change_principal_password = AsyncMock() kadmin.lock_principal = AsyncMock() @@ -201,7 +207,7 @@ async def get_kadmin(self) -> AsyncIterator[AsyncMock]: self._cached_kadmin = None - @provide(scope=Scope.REQUEST, provides=AbstractDHCPManager) + @provide(scope=Scope.APP, provides=AbstractDHCPManager) async def get_dhcp_mngr(self) -> AsyncIterator[AsyncMock]: """Get mock DHCP manager.""" dhcp_manager = AsyncMock(spec=StubDHCPManager) @@ -213,60 +219,58 @@ async def get_dhcp_mngr(self) -> AsyncIterator[AsyncMock]: self._cached_dhcp_manager = None - @provide(scope=Scope.REQUEST, provides=AbstractDNSManager) + @provide(scope=Scope.APP, provides=AbstractDNSManager) async def get_dns_mngr(self) -> AsyncIterator[AsyncMock]: """Get mock DNS manager.""" dns_manager = AsyncMock(spec=StubDNSManager) - dns_manager.setup.return_value = DNSSettingDTO( - 
zone_name="example.com", - dns_server_ip="127.0.0.1", - tsig_key=None, - ) - dns_manager.get_all_records.return_value = [ + dns_manager.get_records.return_value = [ { + "name": "example.com", "type": "A", "records": [ { - "name": "example.com", - "value": "127.0.0.1", - "ttl": 3600, + "content": "127.0.0.1", + "disabled": False, + "modified_at": None, }, ], - }, - ] - dns_manager.get_server_options.return_value = [ - { - "name": "dnssec-validation", - "value": "no", + "ttl": 3600, }, ] dns_manager.get_forward_zones.return_value = [ { - "name": "test.local", - "type": "forward", - "forwarders": [ - "127.0.0.1", - "127.0.0.2", - ], + "id": "forward1", + "name": "forward1.", + "rrsets": [], + "kind": "Forwarded", + "type": "zone", + "servers": ["127.0.0.1"], + "recursion_desired": False, }, ] - dns_manager.get_all_zones_records.return_value = [ + dns_manager.get_master_zones.return_value = [ { - "name": "test.local", - "type": "master", - "records": [ + "id": "zone1", + "name": "example.com.", + "rrsets": [ { + "name": "example.com", "type": "A", "records": [ { - "name": "example.com", - "value": "127.0.0.1", - "ttl": 3600, + "content": "127.0.0.1", + "disabled": False, + "modified_at": None, }, ], + "ttl": 3600, }, ], + "dnssec": False, + "nameservers": ["ns1.example.com."], + "kind": "Master", + "type": "zone", }, ] @@ -277,37 +281,26 @@ async def get_dns_mngr(self) -> AsyncIterator[AsyncMock]: self._cached_dns_manager = None - @provide(scope=Scope.REQUEST, provides=DNSManagerSettings, cache=False) + @provide(scope=Scope.REQUEST, provides=DNSSettingsDTO, cache=False) async def get_dns_mngr_settings( self, dns_state_gateway: DNSStateGateway, - ) -> AsyncIterator["DNSManagerSettings"]: + settings: Settings, + root_dse_gw: DomainReadProtocol, + ) -> AsyncIterator["DNSSettingsDTO"]: """Get DNS manager's settings.""" + domain = await root_dse_gw.get_domain() + yield await dns_state_gateway.get_dns_manager_settings( + settings, + domain.name, + ) - async def resolve() -> 
str: - return "127.0.0.1" - - resolver = resolve() - yield await dns_state_gateway.get_dns_manager_settings(resolver) - weakref.finalize(resolver, resolver.close) - - @provide(scope=Scope.REQUEST, provides=AttributeTypeDAO, cache=False) - def get_attribute_type_dao( - self, - session: AsyncSession, - ) -> AttributeTypeDAO: - """Get Attribute Type DAO.""" - return AttributeTypeDAO(session) - - @provide(scope=Scope.REQUEST, provides=ObjectClassDAO, cache=False) - def get_object_class_dao(self, session: AsyncSession) -> ObjectClassDAO: - """Get Object Class DAO.""" - return ObjectClassDAO(session=session) - - get_entity_type_dao = provide( - EntityTypeDAO, + attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST) + object_class_dao = provide(ObjectClassDAO, scope=Scope.REQUEST) + entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST) + attribute_type_system_flags_use_case = provide( + AttributeTypeSystemFlagsUseCase, scope=Scope.REQUEST, - cache=False, ) attribute_type_use_case = provide( AttributeTypeUseCase, @@ -467,6 +460,19 @@ async def get_redis_for_sessions( with suppress(RuntimeError): await client.aclose() + @provide(scope=Scope.REQUEST, provides=MasterGatewayProtocol) + async def get_master_gateway( + self, + session: AsyncSession, + settings: Settings, + ) -> PGMasterGateway: + return PGMasterGateway(session, settings) + + master_check_use_case = provide( + MasterCheckUseCase, + scope=Scope.REQUEST, + ) + @provide(scope=Scope.APP) async def get_session_storage( self, @@ -1003,7 +1009,7 @@ async def setup_session( name="TEST ONLY LOGIN ROLE", creator_upn=None, is_system=True, - groups=["cn=admin login only,cn=groups,dc=md,dc=test"], + groups=["cn=admin login only,cn=Groups,dc=md,dc=test"], permissions=AuthorizationRules.AUTH_LOGIN, ), ) @@ -1064,6 +1070,15 @@ async def network_policy_gateway( yield await container.get(NetworkPolicyGateway) +@pytest_asyncio.fixture(scope="function") +async def network_policy_use_case( + container: 
AsyncContainer, +) -> AsyncIterator[NetworkPolicyUseCase]: + """Get network policy gateway.""" + async with container(scope=Scope.REQUEST) as container: + yield await container.get(NetworkPolicyUseCase) + + @pytest_asyncio.fixture(scope="function") async def network_policy_validator( container: AsyncContainer, diff --git a/tests/constants.py b/tests/constants.py index 5542e0742..ab5ffb954 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -11,6 +11,7 @@ GROUPS_CONTAINER_NAME, USERS_CONTAINER_NAME, ) +from enums import SamAccountTypeCodes from ldap_protocol.objects import UserAccountControlFlag TEST_DATA = [ @@ -30,8 +31,11 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [DOMAIN_ADMIN_GROUP_NAME], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, + "objectSid": 512, }, { "name": "developers", @@ -42,7 +46,9 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": ["developers"], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, }, { @@ -53,7 +59,9 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": ["admin login only"], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, }, { @@ -64,7 +72,9 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [DOMAIN_USERS_GROUP_NAME], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, }, { @@ -75,7 +85,9 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": [DOMAIN_COMPUTERS_GROUP_NAME], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, }, ], @@ -368,7 +380,11 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": ["testGroup1"], - "sAMAccountType": ["268435456"], + 
"sAMAccountType": [ + str( + SamAccountTypeCodes.SAM_GROUP_OBJECT.value, + ), + ], }, }, ], @@ -381,7 +397,9 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": ["testGroup2"], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, }, ], @@ -402,7 +420,9 @@ "groupType": ["-2147483646"], "instanceType": ["4"], "sAMAccountName": ["testGroup3"], - "sAMAccountType": ["268435456"], + "sAMAccountType": [ + str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), + ], }, }, ], diff --git a/tests/search_request_datasets.py b/tests/search_request_datasets.py index 77557bae9..cb1e7a317 100644 --- a/tests/search_request_datasets.py +++ b/tests/search_request_datasets.py @@ -17,28 +17,28 @@ test_search_by_rule_anr_dataset = [ # with split by space - {"filter": "(anr=Joh Lenno)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, # noqa: E501 - {"filter": "(anr=Lennon John)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, # noqa: E501 - {"filter": "(anr=John Lennon)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, # noqa: E501 - {"filter": "(anr=john lennon)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, # noqa: E501 - {"filter": "(anr==Lennon John)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(anr=Joh Lenno)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(anr=Lennon John)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(anr=John Lennon)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(anr=john lennon)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(anr==Lennon John)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, # noqa: E501 # without split by space - {"filter": "(anr=user0)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr=user0*)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr>=user0)", "objects": 
["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr<=user0)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr~=user0)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr==user0)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr==user0*)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, # noqa: E501 - {"filter": "(aNR=user0*)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr=uSEr0*)", "objects": ["cn=user0,cn=users,dc=md,dc=test"]}, - {"filter": "(anr=domain admins)", "objects": ["cn=domain admins,cn=groups,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(anr=user0)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr=user0*)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr>=user0)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr<=user0)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr~=user0)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr==user0)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr==user0*)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, # noqa: E501 + {"filter": "(aNR=user0*)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr=uSEr0*)", "objects": ["cn=user0,cn=Users,dc=md,dc=test"]}, + {"filter": "(anr=domain admins)", "objects": ["cn=domain admins,cn=Groups,dc=md,dc=test"]}, # noqa: E501 {"filter": "(anr=user_admin_3@mail.com)", "objects": ["cn=user_admin_3,ou=test_bit_rules,dc=md,dc=test"]}, # noqa: E501 { "filter": "(anr=user_admin_*)", "objects": [ - "cn=user_admin,cn=users,dc=md,dc=test", - "cn=user_admin_for_roles,cn=users,dc=md,dc=test", + "cn=user_admin,cn=Users,dc=md,dc=test", + "cn=user_admin_for_roles,cn=Users,dc=md,dc=test", "cn=user_admin_1,ou=test_bit_rules,dc=md,dc=test", "cn=user_admin_2,ou=test_bit_rules,dc=md,dc=test", "cn=user_admin_3,ou=test_bit_rules,dc=md,dc=test", @@ -50,11 +50,11 @@ { "filter": 
f"(useraccountcontrol:1.2.840.113556.1.4.803:={UserAccountControlFlag.NORMAL_ACCOUNT})", # noqa: E501 "objects": [ - "cn=user0,cn=users,dc=md,dc=test", - "cn=user_admin,cn=users,dc=md,dc=test", - "cn=user_admin_for_roles,cn=users,dc=md,dc=test", - "cn=user_non_admin,cn=users,dc=md,dc=test", - "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", + "cn=user0,cn=Users,dc=md,dc=test", + "cn=user_admin,cn=Users,dc=md,dc=test", + "cn=user_admin_for_roles,cn=Users,dc=md,dc=test", + "cn=user_non_admin,cn=Users,dc=md,dc=test", + "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", "cn=user_admin_1,ou=test_bit_rules,dc=md,dc=test", "cn=user_admin_2,ou=test_bit_rules,dc=md,dc=test", ], @@ -83,11 +83,11 @@ { "filter": f"(!(userAccountControl:1.2.840.113556.1.4.803:={UserAccountControlFlag.ACCOUNTDISABLE}))", # noqa: E501 "objects": [ - "cn=user0,cn=users,dc=md,dc=test", - "cn=user_admin,cn=users,dc=md,dc=test", - "cn=user_admin_for_roles,cn=users,dc=md,dc=test", - "cn=user_non_admin,cn=users,dc=md,dc=test", - "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", + "cn=user0,cn=Users,dc=md,dc=test", + "cn=user_admin,cn=Users,dc=md,dc=test", + "cn=user_admin_for_roles,cn=Users,dc=md,dc=test", + "cn=user_non_admin,cn=Users,dc=md,dc=test", + "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", "cn=user_admin_2,ou=test_bit_rules,dc=md,dc=test", ], }, @@ -104,14 +104,14 @@ + UserAccountControlFlag.NORMAL_ACCOUNT })", "objects": [ - "cn=user0,cn=users,dc=md,dc=test", - "cn=user_admin,cn=users,dc=md,dc=test", - "cn=user_admin_for_roles,cn=users,dc=md,dc=test", + "cn=user0,cn=Users,dc=md,dc=test", + "cn=user_admin,cn=Users,dc=md,dc=test", + "cn=user_admin_for_roles,cn=Users,dc=md,dc=test", "cn=user_admin_1,ou=test_bit_rules,dc=md,dc=test", "cn=user_admin_2,ou=test_bit_rules,dc=md,dc=test", "cn=user_admin_3,ou=test_bit_rules,dc=md,dc=test", - "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", - "cn=user_non_admin,cn=users,dc=md,dc=test", + 
"cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", + "cn=user_non_admin,cn=Users,dc=md,dc=test", ], }, { @@ -124,11 +124,11 @@ { "filter": f"(!(userAccountControl:1.2.840.113556.1.4.804:={UserAccountControlFlag.ACCOUNTDISABLE}))", # noqa: E501 "objects": [ - "cn=user0,cn=users,dc=md,dc=test", - "cn=user_admin,cn=users,dc=md,dc=test", - "cn=user_admin_for_roles,cn=users,dc=md,dc=test", - "cn=user_non_admin,cn=users,dc=md,dc=test", - "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", + "cn=user0,cn=Users,dc=md,dc=test", + "cn=user_admin,cn=Users,dc=md,dc=test", + "cn=user_admin_for_roles,cn=Users,dc=md,dc=test", + "cn=user_non_admin,cn=Users,dc=md,dc=test", + "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", "cn=user_admin_2,ou=test_bit_rules,dc=md,dc=test", ], }, diff --git a/tests/test_api/test_audit/test_router.py b/tests/test_api/test_audit/test_router.py index 2abc682b6..bd201d23f 100644 --- a/tests/test_api/test_audit/test_router.py +++ b/tests/test_api/test_audit/test_router.py @@ -10,15 +10,15 @@ from fastapi import status from httpx import AsyncClient +from api.audit.schemas import ( + AuditDestinationSchemaRequest, + AuditPolicySchemaRequest, +) from enums import AuditDestinationProtocolType, AuditDestinationServiceType from ldap_protocol.policies.audit.dataclasses import ( AuditDestinationDTO, AuditPolicyDTO, ) -from ldap_protocol.policies.audit.schemas import ( - AuditDestinationSchemaRequest, - AuditPolicySchemaRequest, -) @pytest.mark.asyncio diff --git a/tests/test_api/test_auth/test_router.py b/tests/test_api/test_auth/test_router.py index c13c0a5a6..0256c0463 100644 --- a/tests/test_api/test_auth/test_router.py +++ b/tests/test_api/test_auth/test_router.py @@ -24,7 +24,7 @@ from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.ldap_requests.modify import Operation from ldap_protocol.session_storage import SessionStorage -from ldap_protocol.utils.queries import get_search_path +from ldap_protocol.utils.queries import 
get_filter_from_path from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds @@ -114,7 +114,7 @@ async def test_first_setup_and_oauth( assert result["user_principal_name"] == "test" assert result["mail"] == "test@md.example-345.ru" assert result["display_name"] == "test" - assert result["dn"] == "cn=test,cn=users,dc=md,dc=test-localhost" + assert result["dn"] == "cn=test,cn=Users,dc=md,dc=test-localhost" result = await session.scalars( select(Directory) @@ -123,9 +123,9 @@ async def test_first_setup_and_oauth( .selectinload(qa(Group.roles)) .selectinload(qa(Role.access_control_entries)), ) - .filter_by( - path=get_search_path( - "cn=read-only,cn=groups,dc=md,dc=test-localhost", + .filter( + get_filter_from_path( + "cn=read-only,cn=Groups,dc=md,dc=test-localhost", ), ), ) @@ -211,7 +211,7 @@ async def test_first_setup_with_invalid_domain( "/auth/setup", json=test_case, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT response = await unbound_http_client.get("/auth/setup") assert response.status_code == status.HTTP_200_OK @@ -222,7 +222,7 @@ async def test_first_setup_with_invalid_domain( @pytest.mark.usefixtures("session") async def test_update_password_and_check_uac(http_client: AsyncClient) -> None: """Update password and check userAccountControl attr.""" - user_dn = "cn=user0,cn=users,dc=md,dc=test" + user_dn = "cn=user0,cn=Users,dc=md,dc=test" response = await http_client.patch( "entry/update", @@ -384,7 +384,7 @@ async def test_update_password_with_empty_old_password( }, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT new_auth = await http_client.post( "auth/", @@ -468,7 +468,7 @@ async def test_auth_disabled_user( response = await http_client.patch( "entry/update", json={ - "object": 
"cn=user_admin,cn=users,dc=md,dc=test", + "object": "cn=user_admin,cn=Users,dc=md,dc=test", "changes": [ { "operation": Operation.REPLACE, @@ -507,7 +507,7 @@ async def test_lock_and_unlock_user( storage: SessionStorage, ) -> None: """Block user and verify nsAccountLock and shadowExpires attributes.""" - user_dn = "cn=user_non_admin,cn=users,dc=md,dc=test" + user_dn = "cn=user_non_admin,cn=Users,dc=md,dc=test" dir_ = await session.scalar( select(Directory) .options(joinedload(qa(Directory.user))) diff --git a/tests/test_api/test_auth/test_sessions.py b/tests/test_api/test_auth/test_sessions.py index 59b11208c..e2c6d3fd9 100644 --- a/tests/test_api/test_auth/test_sessions.py +++ b/tests/test_api/test_auth/test_sessions.py @@ -217,7 +217,7 @@ async def test_block_ldap_user_without_session( storage: SessionStorage, ) -> None: """Test blocking ldap user without active session.""" - user_dn = "cn=user_non_admin,cn=users,dc=md,dc=test" + user_dn = "cn=user_non_admin,cn=Users,dc=md,dc=test" un = "user_non_admin" user = await get_user(session, un) @@ -253,7 +253,7 @@ async def test_block_ldap_user_with_active_session( storage: SessionStorage, ) -> None: """Test blocking ldap user with active session.""" - user_dn = "cn=user_non_admin,cn=users,dc=md,dc=test" + user_dn = "cn=user_non_admin,cn=Users,dc=md,dc=test" un = "user_non_admin" pw = "password" diff --git a/tests/test_api/test_dhcp/test_adapter.py b/tests/test_api/test_dhcp/test_adapter.py index 5d2dd4b26..f67b03016 100644 --- a/tests/test_api/test_dhcp/test_adapter.py +++ b/tests/test_api/test_dhcp/test_adapter.py @@ -10,6 +10,11 @@ import pytest from api.dhcp.adapter import DHCPAdapter +from api.dhcp.schemas import ( + DHCPLeaseSchemaRequest, + DHCPReservationSchemaRequest, + DHCPSubnetSchemaAddRequest, +) from authorization_provider_protocol import AuthorizationProviderProtocol from ldap_protocol.dhcp.dataclasses import ( DHCPLease, @@ -18,11 +23,6 @@ DHCPReservation, DHCPSubnet, ) -from ldap_protocol.dhcp.schemas 
import ( - DHCPLeaseSchemaRequest, - DHCPReservationSchemaRequest, - DHCPSubnetSchemaAddRequest, -) @pytest.fixture diff --git a/tests/test_api/test_dhcp/test_router.py b/tests/test_api/test_dhcp/test_router.py index fbb739816..2ea89d0de 100644 --- a/tests/test_api/test_dhcp/test_router.py +++ b/tests/test_api/test_dhcp/test_router.py @@ -147,7 +147,7 @@ async def test_create_subnet_invalid_data( json=invalid_data, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT @pytest.mark.asyncio @@ -297,7 +297,7 @@ async def test_create_lease_invalid_data( json=invalid_data, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT @pytest.mark.asyncio @@ -486,7 +486,7 @@ async def test_create_reservation_invalid_data( json=invalid_data, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT @pytest.mark.asyncio @@ -597,7 +597,7 @@ async def test_delete_reservation_missing_params( """Test reservation deletion with missing parameters.""" response = await http_client.delete("/dhcp/reservation") - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT @pytest.mark.asyncio diff --git a/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py b/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py index bcfea7210..e04eecc8d 100644 --- a/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py +++ b/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py @@ -115,6 +115,6 @@ { "attribute_type_schemas": [], "attribute_types_deleted": [], - "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY, + "status_code": status.HTTP_422_UNPROCESSABLE_CONTENT, }, ] diff --git 
a/tests/test_api/test_ldap_schema/test_entity_type_router.py b/tests/test_api/test_ldap_schema/test_entity_type_router.py index c130c2067..b7a40c66e 100644 --- a/tests/test_api/test_ldap_schema/test_entity_type_router.py +++ b/tests/test_api/test_ldap_schema/test_entity_type_router.py @@ -78,7 +78,7 @@ async def test_create_one_entity_type_value_422( "is_system": False, }, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT @pytest.mark.parametrize( diff --git a/tests/test_api/test_ldap_schema/test_object_class_router_datasets.py b/tests/test_api/test_ldap_schema/test_object_class_router_datasets.py index 3864c4035..1824cee7a 100644 --- a/tests/test_api/test_ldap_schema/test_object_class_router_datasets.py +++ b/tests/test_api/test_ldap_schema/test_object_class_router_datasets.py @@ -206,7 +206,7 @@ { "object_class_datas": [], "object_classes_deleted": [], - "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY, + "status_code": status.HTTP_422_UNPROCESSABLE_CONTENT, }, { "object_class_datas": [ diff --git a/tests/test_api/test_main/conftest.py b/tests/test_api/test_main/conftest.py index 8f1b58dea..1ee2f69ba 100644 --- a/tests/test_api/test_main/conftest.py +++ b/tests/test_api/test_main/conftest.py @@ -106,7 +106,7 @@ async def adding_test_user( "operation": Operation.ADD, "modification": { "type": "memberOf", - "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "vals": ["cn=domain admins,cn=Groups,dc=md,dc=test"], }, }, { @@ -138,6 +138,30 @@ async def adding_test_user( assert auth.cookies.get("id") +@pytest_asyncio.fixture(scope="function") +async def adding_test_computer( + http_client: AsyncClient, +) -> None: + """Test api correct (name) add.""" + response = await http_client.post( + "/entry/add", + json={ + "entry": "cn=mycomputer,dc=md,dc=test", + "password": None, + "attributes": [ + {"type": "name", "vals": ["mycomputer name"]}, + {"type": "cn", "vals": 
["mycomputer"]}, + {"type": "objectClass", "vals": ["computer", "top"]}, + ], + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.SUCCESS + + @pytest_asyncio.fixture(scope="function") async def add_dns_settings( session: AsyncSession, diff --git a/tests/test_api/test_main/test_dns.py b/tests/test_api/test_main/test_dns.py index 6d521408c..2838166f4 100644 --- a/tests/test_api/test_main/test_dns.py +++ b/tests/test_api/test_main/test_dns.py @@ -1,19 +1,12 @@ """Test DNS service.""" -from dataclasses import asdict - import pytest from httpx import AsyncClient from starlette import status -from ldap_protocol.dns import ( - AbstractDNSManager, - DNSManagerState, - DNSServerParam, - DNSServerParamName, - DNSZoneParam, - DNSZoneParamName, -) +from ldap_protocol.dns import AbstractDNSManager +from ldap_protocol.dns.dto import DNSMasterZoneDTO, DNSRecordDTO, DNSRRSetDTO +from ldap_protocol.dns.enums import DNSRecordType, PowerDNSZoneType @pytest.mark.asyncio @@ -26,12 +19,11 @@ async def test_dns_create_record( zone_name = "hello.zone" hostname = "hello" ip = "127.0.0.1" - record_type = "A" + record_type = DNSRecordType.A ttl = 3600 response = await http_client.post( - "/dns/record", + f"/dns/record/{zone_name}", json={ - "zone_name": zone_name, "record_name": hostname, "record_value": ip, "record_type": record_type, @@ -42,7 +34,20 @@ async def test_dns_create_record( dns_manager.create_record.assert_called() # type: ignore assert ( dns_manager.create_record.call_args.args # type: ignore - ) == (hostname, ip, record_type, int(ttl), zone_name) + ) == ( + zone_name, + DNSRRSetDTO( + name=hostname, + type=record_type, + records=[ + DNSRecordDTO( + content=ip, + disabled=False, + ), + ], + ttl=ttl, + ), + ) assert response.status_code == status.HTTP_200_OK @@ -57,12 +62,11 @@ async def test_dns_delete_record( zone_name = "hello.zone" hostname = "hello" ip = "127.0.0.1" - record_type = "A" + record_type = 
DNSRecordType.A response = await http_client.request( "DELETE", - "/dns/record", + f"/dns/record/{zone_name}", json={ - "zone_name": zone_name, "record_name": hostname, "record_value": ip, "record_type": record_type, @@ -72,7 +76,19 @@ async def test_dns_delete_record( dns_manager.delete_record.assert_called() # type: ignore assert ( dns_manager.delete_record.call_args.args # type: ignore - ) == (hostname, ip, record_type, zone_name) + ) == ( + zone_name, + DNSRRSetDTO( + name=hostname, + type=record_type, + records=[ + DNSRecordDTO( + content=ip, + disabled=False, + ), + ], + ), + ) assert response.status_code == status.HTTP_200_OK @@ -87,13 +103,12 @@ async def test_dns_update_record( zone_name = "hello.zone" hostname = "hello" ip = "127.0.0.1" - record_type = "A" + record_type = DNSRecordType.A ttl = 3600 response = await http_client.request( "PATCH", - "/dns/record", + f"/dns/record/{zone_name}", json={ - "zone_name": zone_name, "record_name": hostname, "record_value": ip, "record_type": record_type, @@ -104,7 +119,20 @@ async def test_dns_update_record( dns_manager.update_record.assert_called() # type: ignore assert ( dns_manager.update_record.call_args.args # type: ignore - ) == (hostname, ip, record_type, int(ttl), zone_name) + ) == ( + zone_name, + DNSRRSetDTO( + name=hostname, + type=record_type, + records=[ + DNSRecordDTO( + content=ip, + disabled=False, + ), + ], + ttl=ttl, + ), + ) assert response.status_code == status.HTTP_200_OK @@ -113,21 +141,25 @@ async def test_dns_update_record( @pytest.mark.usefixtures("session") async def test_dns_get_all_records(http_client: AsyncClient) -> None: """DNS Manager get all records test.""" - response = await http_client.get("/dns/record") + zone_name = "hello.zone" + response = await http_client.get(f"/dns/record/{zone_name}") assert response.status_code == status.HTTP_200_OK data = response.json() assert data == [ { + "name": "example.com", "type": "A", + "changetype": None, "records": [ { - "name": 
"example.com", - "value": "127.0.0.1", - "ttl": 3600, + "content": "127.0.0.1", + "disabled": False, + "modified_at": None, }, ], + "ttl": 3600, }, ] @@ -139,18 +171,12 @@ async def test_dns_setup_selfhosted( dns_manager: AbstractDNSManager, ) -> None: """DNS Manager setup test.""" - dns_status = DNSManagerState.SELFHOSTED - domain = "example.com" - tsig_key = None - dns_ip_address = "127.0.0.1" + response = await http_client.post("/dns/state", json={"state": "1"}) + + assert response.status_code == status.HTTP_200_OK + response = await http_client.post( "/dns/setup", - json={ - "dns_status": dns_status, - "domain": domain, - "dns_ip_address": dns_ip_address, - "tsig_key": tsig_key, - }, ) assert response.status_code == status.HTTP_200_OK @@ -182,28 +208,31 @@ async def test_dns_create_zone( ) -> None: """DNS Manager create zone test.""" zone_name = "hello" - zone_type = "master" - nameserver = None - params = [ - DNSZoneParam( - DNSZoneParamName.acl, - ["127.0.0.1"], - ), - ] + nameserver = "192.168.1.1" response = await http_client.post( "/dns/zone", json={ "zone_name": zone_name, - "zone_type": zone_type, - "params": [asdict(param) for param in params], + "nameserver_ip": nameserver, + "dnssec": False, }, ) assert response.status_code == status.HTTP_200_OK - dns_manager.create_zone.assert_called() # type: ignore + dns_manager.create_master_zone.assert_called() # type: ignore assert ( - dns_manager.create_zone.call_args.args # type: ignore - ) == (zone_name, zone_type, nameserver, params) + dns_manager.create_master_zone.call_args.args # type: ignore + ) == ( + DNSMasterZoneDTO( + id=zone_name, + rrsets=[], + name=zone_name, + dnssec=False, + type="zone", + nameservers=[], + kind=PowerDNSZoneType.MASTER, + ), + ) @pytest.mark.asyncio @@ -215,25 +244,31 @@ async def test_dns_update_zone( ) -> None: """DNS Manager update zone test.""" zone_name = "hello" - params = [ - DNSZoneParam( - DNSZoneParamName.acl, - ["127.0.0.1"], - ), - ] + nameserver = "192.168.1.1" 
response = await http_client.patch( "/dns/zone", json={ "zone_name": zone_name, - "params": [asdict(param) for param in params], + "nameserver_ip": nameserver, + "dnssec": False, }, ) assert response.status_code == status.HTTP_200_OK - dns_manager.update_zone.assert_called() # type: ignore + dns_manager.update_master_zone.assert_called() # type: ignore assert ( - dns_manager.update_zone.call_args.args # type: ignore - ) == (zone_name, params) + dns_manager.update_master_zone.call_args.args # type: ignore + ) == ( + DNSMasterZoneDTO( + id=zone_name, + rrsets=[], + name=zone_name, + dnssec=False, + type="zone", + nameservers=[], + kind=PowerDNSZoneType.MASTER, + ), + ) @pytest.mark.asyncio @@ -244,67 +279,19 @@ async def test_dns_delete_zone( dns_manager: AbstractDNSManager, ) -> None: """DNS Manager delete zone test.""" - zone_names = ["hello"] + zone_ids = ["hello"] response = await http_client.request( "DELETE", "/dns/zone", - json={"zone_names": zone_names}, - ) - - assert response.status_code == status.HTTP_200_OK - dns_manager.delete_zone.assert_called() # type: ignore - assert ( - dns_manager.delete_zone.call_args.args # type: ignore - ) == (zone_names,) - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("add_dns_settings") -@pytest.mark.usefixtures("session") -async def test_dns_update_server_options( - http_client: AsyncClient, - dns_manager: AbstractDNSManager, -) -> None: - """DNS Manager update DNS server options test.""" - params = [ - DNSServerParam( - DNSServerParamName.dnssec, - ["127.0.0.1"], - ), - ] - response = await http_client.patch( - "/dns/server/options", - json=[asdict(param) for param in params], + json={"zone_ids": zone_ids}, ) assert response.status_code == status.HTTP_200_OK - dns_manager.update_server_options.assert_called() # type: ignore + dns_manager.delete_master_zone.assert_called() # type: ignore assert ( - dns_manager.update_server_options.call_args.args # type: ignore - ) == (params,) - - -@pytest.mark.asyncio 
-@pytest.mark.usefixtures("add_dns_settings") -@pytest.mark.usefixtures("session") -async def test_dns_get_server_options( - http_client: AsyncClient, - dns_manager: AbstractDNSManager, -) -> None: - """DNS Manager get DNS server options test.""" - response = await http_client.get("/dns/server/options") - - assert response.status_code == status.HTTP_200_OK - dns_manager.get_server_options.assert_called() # type: ignore - - data = response.json() - assert data == [ - { - "name": "dnssec-validation", - "value": "no", - }, - ] + dns_manager.delete_master_zone.call_args.args # type: ignore + ) == (zone_ids[0],) @pytest.mark.asyncio @@ -318,25 +305,32 @@ async def test_dns_get_all_zones_with_records( response = await http_client.get("/dns/zone") assert response.status_code == status.HTTP_200_OK - dns_manager.get_all_zones_records.assert_called() # type: ignore + dns_manager.get_master_zones.assert_called() # type: ignore data = response.json() assert data == [ { - "name": "test.local", - "type": "master", - "records": [ + "id": "zone1", + "name": "example.com.", + "rrsets": [ { + "name": "example.com", "type": "A", + "changetype": None, "records": [ { - "name": "example.com", - "value": "127.0.0.1", - "ttl": 3600, + "content": "127.0.0.1", + "disabled": False, + "modified_at": None, }, ], + "ttl": 3600, }, ], + "dnssec": False, + "nameservers": ["ns1.example.com."], + "kind": "Master", + "type": "zone", }, ] @@ -357,11 +351,12 @@ async def test_dns_get_all_forward_zones( data = response.json() assert data == [ { - "name": "test.local", - "type": "forward", - "forwarders": [ - "127.0.0.1", - "127.0.0.2", - ], + "id": "forward1", + "name": "forward1.", + "rrsets": [], + "kind": "Forwarded", + "type": "zone", + "servers": ["127.0.0.1"], + "recursion_desired": False, }, ] diff --git a/tests/test_api/test_main/test_kadmin.py b/tests/test_api/test_main/test_kadmin.py index 0ffbd6ebe..b13909357 100644 --- a/tests/test_api/test_main/test_kadmin.py +++ 
b/tests/test_api/test_main/test_kadmin.py @@ -95,7 +95,7 @@ async def test_tree_creation( bind = MutePolicyBindRequest( version=0, - name="cn=krbadmin,cn=users,dc=md,dc=test", + name="cn=krbadmin,cn=Users,dc=md,dc=test", AuthenticationChoice=SimpleAuthentication(password=krbadmin_pw), ) @@ -162,7 +162,7 @@ async def test_setup_call( assert kadmin.setup.call_args.kwargs == { "domain": "md.test", - "admin_dn": "cn=user0,cn=users,dc=md,dc=test", + "admin_dn": "cn=user0,cn=Users,dc=md,dc=test", "services_dn": "ou=System,dc=md,dc=test", "krbadmin_dn": "cn=krbadmin,cn=users,dc=md,dc=test", "krbadmin_password": "Password123", @@ -212,7 +212,10 @@ async def test_ktadd( :param LDAPSession ldap_session: ldap """ names = ["test1", "test2"] - response = await http_client.post("/kerberos/ktadd", json=names) + response = await http_client.post( + "/kerberos/ktadd", + json={"names": names, "is_rand_key": False}, + ) kadmin.ktadd.assert_called() # type: ignore assert kadmin.ktadd.call_args.args[0] == names # type: ignore @@ -240,7 +243,10 @@ async def test_ktadd_400( kadmin.ktadd.side_effect = KRBAPIPrincipalNotFoundError() # type: ignore names = ["test1", "test2"] - response = await http_client.post("/kerberos/ktadd", json=names) + response = await http_client.post( + "/kerberos/ktadd", + json={"names": names, "is_rand_key": False}, + ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -349,7 +355,7 @@ async def test_bind_create_user( assert await proc.wait() == 0 kadmin_args = kadmin.add_principal.call_args.args # type: ignore - assert kadmin_args == (san, pw, 0.1) + assert kadmin_args == (san, pw) @pytest.mark.asyncio @@ -362,7 +368,7 @@ async def test_extended_pw_change_call( kadmin: AbstractKadmin, ) -> None: """Test anonymous pwd change.""" - user_dn = "cn=user0,cn=users,dc=md,dc=test" + user_dn = "cn=user0,cn=Users,dc=md,dc=test" password = creds.pw new_test_password = "Password123" # noqa await anonymous_ldap_client.bind(user_dn, password) @@ -389,20 +395,20 @@ 
async def test_add_princ( :param LDAPSession ldap_session: ldap """ response = await http_client.post( - "/kerberos/principal/add", + "/kerberos/principal", json={ - "primary": "host", - "instance": "12345", + "principal_name": "host/12345", + "password": None, }, ) kadmin_args = kadmin.add_principal.call_args.args # type: ignore assert response.status_code == status.HTTP_200_OK - assert kadmin_args == ("host/12345", None) + assert kadmin_args == ("host/12345", None, None) @pytest.mark.asyncio @pytest.mark.usefixtures("session") -async def test_rename_princ( +async def test_modify_princ( http_client: AsyncClient, kadmin: AbstractKadmin, ) -> None: @@ -411,16 +417,16 @@ async def test_rename_princ( :param AsyncClient http_client: http cl :param LDAPSession ldap_session: ldap """ - response = await http_client.patch( - "/kerberos/principal/rename", + response = await http_client.put( + "/kerberos/principal", json={ "principal_name": "name", - "principal_new_name": "nname", + "new_name": "nname", }, ) - kadmin_args = kadmin.rename_princ.call_args.args # type: ignore + kadmin_args = kadmin.modify_princ.call_args.args # type: ignore assert response.status_code == status.HTTP_200_OK - assert kadmin_args == ("name", "nname") + assert kadmin_args == ("name", "nname", None, None) @pytest.mark.asyncio @@ -434,16 +440,16 @@ async def test_change_princ( :param AsyncClient http_client: http cl :param LDAPSession ldap_session: ldap """ - response = await http_client.patch( - "/kerberos/principal/reset", + response = await http_client.put( + "/kerberos/principal", json={ "principal_name": "name", - "new_password": "pw123", + "password": "pw123", }, ) - kadmin_args = kadmin.change_principal_password.call_args.args # type: ignore + kadmin_args = kadmin.modify_princ.call_args.args # type: ignore assert response.status_code == status.HTTP_200_OK - assert kadmin_args == ("name", "pw123") + assert kadmin_args == ("name", None, None, "pw123") @pytest.mark.asyncio diff --git 
a/tests/test_api/test_main/test_router/test_add.py b/tests/test_api/test_main/test_router/test_add.py index 3050bedec..c83803322 100644 --- a/tests/test_api/test_main/test_router/test_add.py +++ b/tests/test_api/test_main/test_router/test_add.py @@ -8,6 +8,7 @@ from fastapi import status from httpx import AsyncClient +from enums import SamAccountTypeCodes from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.objects import UserAccountControlFlag from tests.api_datasets import test_api_forbidden_chars_in_attr_value @@ -28,7 +29,7 @@ async def test_api_correct_add(http_client: AsyncClient) -> None: {"type": "objectClass", "vals": ["organization", "top"]}, { "type": "memberOf", - "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "vals": ["cn=domain admins,cn=Groups,dc=md,dc=test"], }, ], }, @@ -42,35 +43,6 @@ async def test_api_correct_add(http_client: AsyncClient) -> None: assert data.get("errorMessage") == "" -@pytest.mark.asyncio -@pytest.mark.usefixtures("session") -async def test_api_add_incorrect_computer_name( - http_client: AsyncClient, -) -> None: - """Test api incorrect (name) add.""" - response = await http_client.post( - "/entry/add", - json={ - "entry": "cn=test,dc=md,dc=test", - "password": None, - "attributes": [ - {"type": "name", "vals": [" test;incorrect"]}, - {"type": "cn", "vals": ["test"]}, - {"type": "objectClass", "vals": ["computer", "top"]}, - { - "type": "memberOf", - "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], - }, - ], - }, - ) - - data = response.json() - - assert isinstance(data, dict) - assert data.get("resultCode") == LDAPCodes.UNDEFINED_ATTRIBUTE_TYPE - - @pytest.mark.asyncio @pytest.mark.usefixtures("session") async def test_api_add_incorrect_user_samaccount_with_dot( @@ -171,6 +143,57 @@ async def test_api_add_computer(http_client: AsyncClient) -> None: else: raise Exception("Computer without userAccountControl") + for attr in data["search_result"][0]["partial_attributes"]: + if attr["type"] == 
"sAMAccountName": + assert attr["vals"][0] == "PC" + break + else: + raise Exception("Computer without sAMAccountName") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_add_user_samaccounttype( + http_client: AsyncClient, +) -> None: + """Add user without sAMAccountType: server sets SAM_USER_OBJECT.""" + entry = "cn=samuser,dc=md,dc=test" + await http_client.post( + "/entry/add", + json={ + "entry": entry, + "password": "P@ssw0rd", + "attributes": [ + {"type": "name", "vals": ["samuser"]}, + {"type": "cn", "vals": ["samuser"]}, + {"type": "objectClass", "vals": ["user", "top"]}, + {"type": "sAMAccountName", "vals": ["samuser"]}, + {"type": "userPrincipalName", "vals": ["samuser@md.test"]}, + ], + }, + ) + response = await http_client.post( + "entry/search", + json={ + "base_object": entry, + "scope": 0, + "deref_aliases": 0, + "size_limit": 1, + "time_limit": 10, + "types_only": True, + "filter": "(objectClass=*)", + "attributes": ["sAMAccountType"], + "page_number": 1, + }, + ) + data = response.json() + attrs = { + a["type"]: a for a in data["search_result"][0]["partial_attributes"] + } + assert attrs["sAMAccountType"]["vals"][0] == str( + SamAccountTypeCodes.SAM_USER_OBJECT, + ) + @pytest.mark.asyncio @pytest.mark.usefixtures("session") @@ -186,7 +209,7 @@ async def test_api_correct_add_double_member_of( user = "cn=test0,dc=md,dc=test" un = "test0" groups = [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", new_group, ] @@ -307,7 +330,7 @@ async def test_api_correct_add_double_member_of( assert data.get("resultCode") == LDAPCodes.SUCCESS assert data["search_result"][0]["object_name"] == user - created_groups = groups + ["cn=domain users,cn=groups,dc=md,dc=test"] + created_groups = groups + ["cn=domain users,cn=Groups,dc=md,dc=test"] for attr in data["search_result"][0]["partial_attributes"]: if attr["type"] == "memberOf": @@ -528,7 +551,7 @@ async def test_api_double_add(http_client: 
AsyncClient) -> None: { "type": "memberOf", "vals": [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", ], }, ], @@ -568,7 +591,7 @@ async def test_api_add_double_case_insensetive( { "type": "memberOf", "vals": [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", ], }, ], @@ -597,7 +620,7 @@ async def test_api_add_double_case_insensetive( { "type": "memberOf", "vals": [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", ], }, ], diff --git a/tests/test_api/test_main/test_router/test_modify.py b/tests/test_api/test_main/test_router/test_modify.py index 3e46e879d..321c38795 100644 --- a/tests/test_api/test_main/test_router/test_modify.py +++ b/tests/test_api/test_main/test_router/test_modify.py @@ -8,6 +8,7 @@ from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession +from ldap_protocol.kerberos.base import AbstractKadmin from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.ldap_requests.modify import Operation @@ -16,10 +17,13 @@ @pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") -async def test_api_correct_modify(http_client: AsyncClient) -> None: +async def test_api_correct_modify_user_accountexpires( + http_client: AsyncClient, +) -> None: """Test API for modify object attribute.""" entry_dn = "cn=test,dc=md,dc=test" new_value = "133632677730000000" + response = await http_client.patch( "/entry/update", json={ @@ -37,7 +41,6 @@ async def test_api_correct_modify(http_client: AsyncClient) -> None: ) data = response.json() - assert isinstance(data, dict) assert data.get("resultCode") == LDAPCodes.SUCCESS @@ -57,39 +60,145 @@ async def test_api_correct_modify(http_client: AsyncClient) -> None: ) data = response.json() - assert data["resultCode"] == LDAPCodes.SUCCESS assert data["search_result"][0]["object_name"] == entry_dn for attr in 
data["search_result"][0]["partial_attributes"]: if attr["type"] == "accountExpires": assert attr["vals"][0] == new_value + break + else: + raise Exception("User without accountExpires") @pytest.mark.asyncio +@pytest.mark.usefixtures("adding_test_user") @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") -async def test_api_duplicate_with_spaces_modify( +async def test_api_correct_modify_user_samaccountname( http_client: AsyncClient, + kadmin: AbstractKadmin, ) -> None: - """Test API for modify duplicated object name.""" - entry_dn = "cn=new_test,dc=md,dc=test" + """Test API for modify object attribute.""" + entry_dn = "cn=test,dc=md,dc=test" + + response = await http_client.patch( + "/entry/update", + json={ + "object": entry_dn, + "changes": [ + { + "operation": Operation.REPLACE, + "modification": { + "type": "sAMAccountName", + "vals": ["NEW user name"], + }, + }, + ], + }, + ) + + data = response.json() + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.SUCCESS + assert kadmin.modify_princ.call_args.args == ("new_user", "NEW user name") # type: ignore + response = await http_client.post( - "/entry/add", + "entry/search", json={ - "entry": entry_dn, - "password": None, - "attributes": [ + "base_object": entry_dn, + "scope": 0, + "deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": "(objectClass=*)", + "attributes": [], + "page_number": 1, + }, + ) + + data = response.json() + assert data["resultCode"] == LDAPCodes.SUCCESS + assert data["search_result"][0]["object_name"] == entry_dn + + for attr in data["search_result"][0]["partial_attributes"]: + if attr["type"] == "sAMAccountName": + assert attr["vals"][0] == "NEW user name" + break + else: + raise Exception("User without sAMAccountName") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("adding_test_user") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +async def 
test_api_correct_modify_user_userprincipalname( + http_client: AsyncClient, + kadmin: AbstractKadmin, +) -> None: + """Test API for modify object attribute.""" + entry_dn = "cn=test,dc=md,dc=test" + + response = await http_client.patch( + "/entry/update", + json={ + "object": entry_dn, + "changes": [ { - "type": "objectClass", - "vals": ["organization", "top"], + "operation": Operation.REPLACE, + "modification": { + "type": "userPrincipalName", + "vals": ["newbiguser@md.test"], + }, }, ], }, ) + data = response.json() + assert isinstance(data, dict) assert data.get("resultCode") == LDAPCodes.SUCCESS + assert kadmin.modify_princ.call_args.args == ("new_user", "newbiguser") # type: ignore + response = await http_client.post( + "entry/search", + json={ + "base_object": entry_dn, + "scope": 0, + "deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": "(objectClass=*)", + "attributes": [], + "page_number": 1, + }, + ) + + data = response.json() + assert data["resultCode"] == LDAPCodes.SUCCESS + assert data["search_result"][0]["object_name"] == entry_dn + + for attr in data["search_result"][0]["partial_attributes"]: + if attr["type"] == "userPrincipalName": + assert attr["vals"][0] == "newbiguser@md.test" + break + else: + raise Exception("User without userPrincipalName") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("adding_test_computer") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +async def test_api_correct_modify_computer_samaccountname_replace( + http_client: AsyncClient, + kadmin: AbstractKadmin, +) -> None: + """Test API for modify computer sAMAccountName.""" + entry_dn = "cn=mycomputer,dc=md,dc=test" response = await http_client.patch( "/entry/update", json={ @@ -98,8 +207,8 @@ async def test_api_duplicate_with_spaces_modify( { "operation": Operation.REPLACE, "modification": { - "type": "cn", - "vals": [" test"], + "type": "sAMAccountName", + "vals": ["maincomputer"], }, }, ], @@ 
-110,6 +219,15 @@ async def test_api_duplicate_with_spaces_modify( assert isinstance(data, dict) assert data.get("resultCode") == LDAPCodes.SUCCESS + assert kadmin.modify_princ.call_count == 2 # type: ignore + assert kadmin.modify_princ.call_args_list[0].args == ( # type: ignore + "host/mycomputer", + "host/maincomputer", + ) + assert kadmin.modify_princ.call_args_list[1].args == ( # type: ignore + "host/mycomputer.md.test", + "host/maincomputer.md.test", + ) response = await http_client.post( "entry/search", @@ -127,9 +245,48 @@ async def test_api_duplicate_with_spaces_modify( ) data = response.json() - assert isinstance(data, dict) + + assert data["resultCode"] == LDAPCodes.SUCCESS assert data["search_result"][0]["object_name"] == entry_dn + for attr in data["search_result"][0]["partial_attributes"]: + if attr["type"] == "sAMAccountName": + assert attr["vals"][0] == "maincomputer" + break + else: + raise Exception("Computer without sAMAccountName") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("adding_test_computer") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +async def test_api_incorrect_modify_computer_samaccountname_add( + http_client: AsyncClient, +) -> None: + """Test API for modify computer sAMAccountName.""" + entry_dn = "cn=mycomputer,dc=md,dc=test" + response = await http_client.patch( + "/entry/update", + json={ + "object": entry_dn, + "changes": [ + { + "operation": Operation.ADD, + "modification": { + "type": "sAMAccountName", + "vals": ["maincomputer"], + }, + }, + ], + }, + ) + + data = response.json() + + assert isinstance(data, dict) + assert data.get("resultCode") == LDAPCodes.OPERATIONS_ERROR + @pytest.mark.asyncio @pytest.mark.usefixtures("adding_test_user") @@ -262,8 +419,8 @@ async def test_api_correct_modify_replace_memberof( http_client: AsyncClient, ) -> None: """Test API for modify object attribute.""" - user = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" - new_group = "cn=domain 
admins,cn=groups,dc=md,dc=test" + user = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" + new_group = "cn=domain admins,cn=Groups,dc=md,dc=test" response = await http_client.patch( "/entry/update", json={ @@ -320,13 +477,13 @@ async def test_api_modify_add_loop_detect_member( response = await http_client.patch( "/entry/update", json={ - "object": "cn=developers,cn=groups,dc=md,dc=test", + "object": "cn=developers,cn=Groups,dc=md,dc=test", "changes": [ { "operation": Operation.ADD, "modification": { "type": "member", - "vals": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "vals": ["cn=domain admins,cn=Groups,dc=md,dc=test"], }, }, ], @@ -347,13 +504,13 @@ async def test_api_modify_add_loop_detect_memberof( response = await http_client.patch( "/entry/update", json={ - "object": "cn=domain admins,cn=groups,dc=md,dc=test", + "object": "cn=domain admins,cn=Groups,dc=md,dc=test", "changes": [ { "operation": Operation.ADD, "modification": { "type": "memberOf", - "vals": ["cn=developers,cn=groups,dc=md,dc=test"], + "vals": ["cn=developers,cn=Groups,dc=md,dc=test"], }, }, ], @@ -374,15 +531,15 @@ async def test_api_modify_replace_loop_detect_member( response = await http_client.patch( "/entry/update", json={ - "object": "cn=developers,cn=groups,dc=md,dc=test", + "object": "cn=developers,cn=Groups,dc=md,dc=test", "changes": [ { "operation": Operation.REPLACE, "modification": { "type": "member", "vals": [ - "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", ], }, }, @@ -404,15 +561,15 @@ async def test_api_modify_replace_loop_detect_memberof( response = await http_client.patch( "/entry/update", json={ - "object": "cn=domain admins,cn=groups,dc=md,dc=test", + "object": "cn=domain admins,cn=Groups,dc=md,dc=test", "changes": [ { "operation": Operation.REPLACE, "modification": { "type": "memberOf", "vals": [ - "cn=domain 
computers,cn=groups,dc=md,dc=test", - "cn=developers,cn=groups,dc=md,dc=test", + "cn=domain computers,cn=Groups,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", ], }, }, @@ -431,7 +588,7 @@ async def test_api_modify_incorrect_uac(http_client: AsyncClient) -> None: response = await http_client.patch( "/entry/update", json={ - "object": "cn=user0,cn=users,dc=md,dc=test", + "object": "cn=user0,cn=Users,dc=md,dc=test", "changes": [ { "operation": Operation.REPLACE, @@ -455,7 +612,7 @@ async def test_qpi_modify_primary_object_classes( http_client: AsyncClient, ) -> None: """Test deleting primary object class.""" - entry_dn = "cn=user0,cn=users,dc=md,dc=test" + entry_dn = "cn=user0,cn=Users,dc=md,dc=test" response = await http_client.patch( "/entry/update", json={ @@ -487,7 +644,7 @@ async def test_api_set_primary_group( ) -> None: """Test API for setting primary group.""" user_dn = "cn=test,dc=md,dc=test" - group_dn = "cn=domain admins,cn=groups,dc=md,dc=test" + group_dn = "cn=domain admins,cn=Groups,dc=md,dc=test" response = await http_client.post( "/entry/set_primary_group", diff --git a/tests/test_api/test_main/test_router/test_modify_dn.py b/tests/test_api/test_main/test_router/test_modify_dn.py index b27360dae..8313049f5 100644 --- a/tests/test_api/test_main/test_router/test_modify_dn.py +++ b/tests/test_api/test_main/test_router/test_modify_dn.py @@ -219,13 +219,13 @@ async def test_api_modify_dn_with_level_up( @pytest.mark.usefixtures("session") async def test_api_correct_update_dn(http_client: AsyncClient) -> None: """Test API for update DN.""" - old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" newrdn_user = "cn=new_test2" - old_group_dn = "cn=developers,cn=groups,dc=md,dc=test" - new_group_dn = "cn=new_developers,cn=groups,dc=md,dc=test" + old_group_dn = "cn=developers,cn=Groups,dc=md,dc=test" + new_group_dn = "cn=new_developers,cn=Groups,dc=md,dc=test" 
newrdn_group = "cn=new_developers" - new_superior_group = "cn=groups,dc=md,dc=test" + new_superior_group = "cn=Groups,dc=md,dc=test" new_user_dn = ",".join((newrdn_user, new_superior_group)) @@ -338,8 +338,8 @@ async def test_api_correct_update_dn(http_client: AsyncClient) -> None: @pytest.mark.usefixtures("session") async def test_api_update_dn_with_parent(http_client: AsyncClient) -> None: """Test API for update DN.""" - old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" - new_user_dn = "cn=new_test2,cn=users,dc=md,dc=test" + old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" + new_user_dn = "cn=new_test2,cn=Users,dc=md,dc=test" groups_user = None newrdn_user, new_superior = new_user_dn.split(",", maxsplit=1) @@ -540,3 +540,81 @@ async def test_api_update_dn_invalid_new_superior( assert isinstance(data, dict) assert data.get("resultCode") == LDAPCodes.INVALID_DN_SYNTAX + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("adding_test_user") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +async def test_api_modify_dn_many(http_client: AsyncClient) -> None: + """Test API for bulk modify DN.""" + entry_dn_1 = "cn=test,dc=md,dc=test" + entry_dn_2 = "cn=test2,dc=md,dc=test" + + response = await http_client.post( + "/entry/add", + json={ + "entry": entry_dn_2, + "password": None, + "attributes": [ + {"type": "name", "vals": ["test2"]}, + {"type": "cn", "vals": ["test2"]}, + {"type": "objectClass", "vals": ["organization", "top"]}, + ], + }, + ) + assert response.json()["resultCode"] == LDAPCodes.SUCCESS + + response = await http_client.post( + "/entry/update_many/dn", + json=[ + { + "entry": entry_dn_1, + "newrdn": "cn=test", + "deleteoldrdn": True, + "new_superior": "ou=testModifyDn1,dc=md,dc=test", + }, + { + "entry": entry_dn_2, + "newrdn": "cn=test2", + "deleteoldrdn": True, + "new_superior": "ou=testModifyDn1,dc=md,dc=test", + }, + ], + ) + + data = response.json() + assert all( + 
result.get("resultCode") == LDAPCodes.SUCCESS for result in data + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("adding_test_user") +@pytest.mark.usefixtures("setup_session") +@pytest.mark.usefixtures("session") +async def test_api_modify_dn_many_with_error(http_client: AsyncClient) -> None: + """Test bulk modify DN with one invalid entry.""" + entry_dn = "cn=test,dc=md,dc=test" + + response = await http_client.post( + "/entry/update_many/dn", + json=[ + { + "entry": entry_dn, + "newrdn": "cn=test", + "deleteoldrdn": True, + "new_superior": "ou=testModifyDn1,dc=md,dc=test", + }, + { + "entry": "cn=nonExistent,dc=md,dc=test", + "newrdn": "cn=nonExistent", + "deleteoldrdn": True, + "new_superior": "dc=md,dc=test", + }, + ], + ) + + data = response.json() + assert data[0].get("resultCode") == LDAPCodes.SUCCESS + assert data[1].get("resultCode") == LDAPCodes.NO_SUCH_OBJECT diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index 34a9377aa..1c591bd17 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -9,6 +9,7 @@ from enums import EntityTypeNames from ldap_protocol.ldap_codes import LDAPCodes +from ldap_protocol.ldap_requests.modify import Operation from tests.search_request_datasets import ( test_search_by_rule_anr_dataset, test_search_by_rule_bit_and_dataset, @@ -72,6 +73,39 @@ async def test_api_root_dse(http_client: AsyncClient) -> None: assert all(attr in aquired_attrs for attr in root_attrs) +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_root_dse_return_one_attr(http_client: AsyncClient) -> None: + """Test api root dse.""" + response = await http_client.post( + "entry/search", + json={ + "base_object": "", + "scope": 0, + "deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": "(objectClass=*)", + "attributes": ["namingContexts"], + 
"page_number": 1, + }, + ) + + data = response.json() + + attrs = sorted( + data["search_result"][0]["partial_attributes"], + key=lambda x: x["type"], + ) + + aquired_attrs = {attr["type"] for attr in attrs} + root_attrs = {"namingContexts"} + + assert data["search_result"][0]["object_name"] == "" + assert aquired_attrs == root_attrs + + @pytest.mark.asyncio @pytest.mark.usefixtures("session") async def test_api_search(http_client: AsyncClient) -> None: @@ -96,8 +130,8 @@ async def test_api_search(http_client: AsyncClient) -> None: assert response["resultCode"] == LDAPCodes.SUCCESS sub_dirs = { - "cn=groups,dc=md,dc=test", - "cn=users,dc=md,dc=test", + "cn=Groups,dc=md,dc=test", + "cn=Users,dc=md,dc=test", "ou=testModifyDn1,dc=md,dc=test", "ou=testModifyDn3,dc=md,dc=test", "ou=test_bit_rules,dc=md,dc=test", @@ -111,7 +145,7 @@ async def test_api_search(http_client: AsyncClient) -> None: @pytest.mark.usefixtures("session") async def test_api_search_filter_memberof(http_client: AsyncClient) -> None: """Test api search.""" - member = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + member = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" raw_response = await http_client.post( "entry/search", json={ @@ -121,7 +155,7 @@ async def test_api_search_filter_memberof(http_client: AsyncClient) -> None: "size_limit": 1000, "time_limit": 10, "types_only": True, - "filter": "(memberOf=cn=developers,cn=groups,dc=md,dc=test)", + "filter": "(memberOf=cn=developers,cn=Groups,dc=md,dc=test)", "attributes": [], "page_number": 1, }, @@ -137,8 +171,8 @@ async def test_api_search_filter_memberof(http_client: AsyncClient) -> None: @pytest.mark.usefixtures("session") async def test_api_search_filter_member(http_client: AsyncClient) -> None: """Test api search.""" - member = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" - group = "cn=developers,cn=groups,dc=md,dc=test" + member = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" + group = 
"cn=developers,cn=Groups,dc=md,dc=test" raw_response = await http_client.post( "entry/search", json={ @@ -241,11 +275,11 @@ async def test_api_search_filter_account_expires( @pytest.mark.usefixtures("session") async def test_api_search_complex_filter(http_client: AsyncClient) -> None: """Test api search.""" - user = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + user = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" raw_response = await http_client.post( "entry/search", json={ - "base_object": "cn=users,dc=md,dc=test", + "base_object": "cn=Users,dc=md,dc=test", "scope": 2, "deref_aliases": 0, "size_limit": 1000, @@ -278,12 +312,12 @@ async def test_api_search_complex_filter(http_client: AsyncClient) -> None: @pytest.mark.usefixtures("session") async def test_api_search_recursive_memberof(http_client: AsyncClient) -> None: """Test api search.""" - group = "cn=domain admins,cn=groups,dc=md,dc=test" + group = "cn=domain admins,cn=Groups,dc=md,dc=test" members = [ - "cn=developers,cn=groups,dc=md,dc=test", - "cn=user0,cn=users,dc=md,dc=test", - "cn=user_admin,cn=users,dc=md,dc=test", - "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", + "cn=user0,cn=Users,dc=md,dc=test", + "cn=user_admin,cn=Users,dc=md,dc=test", + "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", ] response = await http_client.post( "entry/search", @@ -304,6 +338,115 @@ async def test_api_search_recursive_memberof(http_client: AsyncClient) -> None: assert all(obj["object_name"] in members for obj in data["search_result"]) +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_search_recursive_member( + http_client: AsyncClient, +) -> None: + """Test recursive member search for user0.""" + user = "cn=user0,cn=users,dc=md,dc=test" + expected_groups = [ + "cn=domain admins,cn=Groups,dc=md,dc=test", + ] + response = await http_client.post( + "entry/search", + json={ + "base_object": "dc=md,dc=test", + "scope": 2, + 
"deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": f"(member:1.2.840.113556.1.4.1941:={user})", + "attributes": [], + "page_number": 1, + }, + ) + data = response.json() + assert data["resultCode"] == LDAPCodes.SUCCESS + dns = {obj["object_name"] for obj in data["search_result"]} + for group in expected_groups: + assert group in dns, f"Group {group} not found in search results" + assert len(data["search_result"]) >= 1 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_search_recursive_member_for_many_roots( + http_client: AsyncClient, +) -> None: + """Test recursive member search with nested groups chain.""" + + async def _create_group(dn: str, name: str) -> None: + response = await http_client.post( + "/entry/add", + json={ + "entry": dn, + "password": None, + "attributes": [ + {"type": "name", "vals": [name]}, + {"type": "cn", "vals": [name]}, + { + "type": "objectClass", + "vals": ["top", "posixGroup", "group"], + }, + ], + }, + ) + assert response.json().get("resultCode") == LDAPCodes.SUCCESS + + async def _add_member(dn: str, member: str) -> None: + response = await http_client.patch( + "/entry/update", + json={ + "object": dn, + "changes": [ + { + "operation": Operation.ADD, + "modification": {"type": "member", "vals": [member]}, + }, + ], + }, + ) + assert response.json().get("resultCode") == LDAPCodes.SUCCESS + + group1_dn = "cn=recursive_test_group1,cn=Groups,dc=md,dc=test" + group2_dn = "cn=recursive_test_group2,cn=Groups,dc=md,dc=test" + group3_dn = "cn=recursive_test_group3,cn=Groups,dc=md,dc=test" + user = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + + await _create_group(group3_dn, "recursive_test_group3") + await _create_group(group2_dn, "recursive_test_group2") + await _create_group(group1_dn, "recursive_test_group1") + + await _add_member(group1_dn, user) + await _add_member(group2_dn, group1_dn) + await _add_member(group3_dn, group2_dn) + + response = await 
http_client.post( + "entry/search", + json={ + "base_object": "dc=md,dc=test", + "scope": 2, + "deref_aliases": 0, + "size_limit": 1000, + "time_limit": 10, + "types_only": True, + "filter": f"(member:1.2.840.113556.1.4.1941:={user})", + "attributes": [], + "page_number": 1, + }, + ) + data = response.json() + assert data["resultCode"] == LDAPCodes.SUCCESS + dns = {obj["object_name"] for obj in data["search_result"]} + + expected_groups = [group1_dn, group2_dn, group3_dn] + for group in expected_groups: + assert group in dns + assert "cn=domain admins,cn=Groups,dc=md,dc=test" in dns + + @pytest.mark.asyncio @pytest.mark.usefixtures("session") @pytest.mark.parametrize("dataset", test_search_by_rule_anr_dataset) @@ -406,7 +549,7 @@ async def test_api_bytes_to_hex(http_client: AsyncClient) -> None: raw_response = await http_client.post( "entry/search", json={ - "base_object": "cn=user0,cn=users,dc=md,dc=test", + "base_object": "cn=user0,cn=Users,dc=md,dc=test", "scope": 0, "deref_aliases": 0, "size_limit": 1000, @@ -492,3 +635,35 @@ async def test_api_empty_search( assert response["resultCode"] == LDAPCodes.SUCCESS assert not response["search_result"] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_get_group_name_by_primary_group_id( + http_client: AsyncClient, +) -> None: + """Test api get group path DN by primary group id.""" + primary_group_id = 512 + path_dn = "cn=domain admins,cn=Groups,dc=md,dc=test" + response = await http_client.get( + f"entry/group/primary/{primary_group_id}", + ) + + assert response.status_code == 200 + response = response.json() + + assert response["path_dn"] == path_dn + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_api_get_group_path_dn_by_primary_group_id_not_found( + http_client: AsyncClient, +) -> None: + """Test api get group path DN by primary group id not found.""" + primary_group_id = 513 + response = await http_client.get( + 
f"entry/group/primary/{primary_group_id}", + ) + + assert response.status_code == 404 diff --git a/tests/test_api/test_network/test_router.py b/tests/test_api/test_network/test_router.py index 9155ff65d..b98759bcd 100644 --- a/tests/test_api/test_network/test_router.py +++ b/tests/test_api/test_network/test_router.py @@ -68,7 +68,7 @@ async def test_add_policy(http_client: AsyncClient) -> None: "name": "local seriveses", "netmasks": raw_netmasks, "priority": 2, - "groups": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "groups": ["cn=domain admins,cn=Groups,dc=md,dc=test"], "is_http": True, "is_ldap": True, "is_kerberos": True, @@ -108,7 +108,7 @@ async def test_add_policy(http_client: AsyncClient) -> None: "name": "local seriveses", "netmasks": compare_netmasks, "raw": raw_netmasks, - "groups": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "groups": ["cn=domain admins,cn=Groups,dc=md,dc=test"], "priority": 2, "mfa_groups": [], "mfa_status": 0, @@ -153,7 +153,7 @@ async def test_update_policy(http_client: AsyncClient) -> None: "/policy", json={ "id": pol_id, - "groups": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "groups": ["cn=domain admins,cn=Groups,dc=md,dc=test"], "name": "Default open policy 2", }, ) @@ -168,7 +168,7 @@ async def test_update_policy(http_client: AsyncClient) -> None: "name": "Default open policy 2", "netmasks": ["0.0.0.0/0"], "raw": ["0.0.0.0/0"], - "groups": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "groups": ["cn=domain admins,cn=Groups,dc=md,dc=test"], "mfa_groups": [], "mfa_status": 0, "priority": 1, @@ -194,7 +194,7 @@ async def test_update_policy(http_client: AsyncClient) -> None: "mfa_groups": [], "mfa_status": 0, "priority": 1, - "groups": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "groups": ["cn=domain admins,cn=Groups,dc=md,dc=test"], "is_http": True, "is_ldap": True, "is_kerberos": True, @@ -260,7 +260,7 @@ async def test_delete_policy( assert response[0]["priority"] == 1 response = await 
http_client.delete(f"/policy/{pol_id2}") - assert response.status_code == 422 + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT assert response.json()["detail"] == "At least one policy should be active" @@ -314,7 +314,7 @@ async def test_switch_policy( assert response.json()[0]["enabled"] is False response = await http_client.patch(f"/policy/{pol_id2}") - assert response.status_code == 422 + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT assert response.json()["detail"] == "At least one policy should be active" @@ -363,7 +363,7 @@ async def test_swap(http_client: AsyncClient) -> None: "172.8.4.0/24", ], "priority": 2, - "groups": ["cn=domain admins,cn=groups,dc=md,dc=test"], + "groups": ["cn=domain admins,cn=Groups,dc=md,dc=test"], "is_http": True, "is_ldap": True, "is_kerberos": True, @@ -399,7 +399,7 @@ async def test_swap(http_client: AsyncClient) -> None: assert response[0]["priority"] == 1 assert response[0]["groups"] == [ - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", ] assert response[1]["priority"] == 2 assert response[1]["name"] == "Default open policy" diff --git a/tests/test_api/test_password_policy/test_password_policy_router.py b/tests/test_api/test_password_policy/test_password_policy_router.py index 0e3dbba8c..01ff38c44 100644 --- a/tests/test_api/test_password_policy/test_password_policy_router.py +++ b/tests/test_api/test_password_policy/test_password_policy_router.py @@ -77,7 +77,7 @@ async def test_get_password_policy_by_dir_path_dn_with_error( password_use_cases: Mock, ) -> None: """Test get one Password Policy endpoint.""" - path = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + path = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" response = await http_client_with_login_perm.get( f"/password-policy/by_dir_path_dn/{path}", ) @@ -94,7 +94,7 @@ async def test_get_password_policy_by_dir_path_dn( password_use_cases: Mock, ) -> None: """Test get 
Password Policy by directory path endpoint.""" - path = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + path = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" response = await http_client.get( f"/password-policy/by_dir_path_dn/{path}", ) diff --git a/tests/test_api/test_shadow/conftest.py b/tests/test_api/test_shadow/conftest.py index ab3cb4df8..661b3a307 100644 --- a/tests/test_api/test_shadow/conftest.py +++ b/tests/test_api/test_shadow/conftest.py @@ -50,7 +50,7 @@ async def adding_mfa_user_and_group( response = await http_client.post( "/entry/add", json={ - "entry": "cn=mfa_group,cn=groups,dc=md,dc=test", + "entry": "cn=mfa_group,cn=Groups,dc=md,dc=test", "password": None, "attributes": [ { @@ -111,8 +111,8 @@ async def adding_mfa_user_and_group( { "type": "memberOf", "vals": [ - "cn=mfa_group,cn=groups,dc=md,dc=test", - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=mfa_group,cn=Groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", ], }, { diff --git a/tests/test_ldap/policies/test_network/test_pool_client_handler.py b/tests/test_ldap/policies/test_network/test_pool_client_handler.py index 9f212986c..0d2b4c800 100644 --- a/tests/test_ldap/policies/test_network/test_pool_client_handler.py +++ b/tests/test_ldap/policies/test_network/test_pool_client_handler.py @@ -78,7 +78,7 @@ async def test_check_policy_group( assert await network_policy_validator.is_user_group_valid(user, policy) group = await get_group( - dn="cn=domain admins,cn=groups,dc=md,dc=test", + dn="cn=domain admins,cn=Groups,dc=md,dc=test", session=session, ) diff --git a/tests/test_ldap/policies/test_network/test_use_case.py b/tests/test_ldap/policies/test_network/test_use_case.py new file mode 100644 index 000000000..d9714f881 --- /dev/null +++ b/tests/test_ldap/policies/test_network/test_use_case.py @@ -0,0 +1,71 @@ +"""Test network policy use case with empty groups. 
+ +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Network + +import pytest + +from enums import MFAFlags +from ldap_protocol.policies.network import NetworkPolicyUseCase +from ldap_protocol.policies.network.dto import ( + NetworkPolicyDTO, + NetworkPolicyUpdateDTO, +) + + +@pytest.mark.asyncio +async def test_create_policy( + network_policy_use_case: NetworkPolicyUseCase, +) -> None: + """Test creating policy with empty groups and mfa_groups.""" + dto = NetworkPolicyDTO[None]( + id=None, + name="Test Empty Groups", + netmasks=[IPv4Network("192.168.1.0/24")], + raw=["192.168.1.0/24"], + priority=2, + mfa_status=MFAFlags.DISABLED, + groups=[], + mfa_groups=[], + ) + + result = await network_policy_use_case.create(dto) + poicy = await network_policy_use_case.get(result.id) + assert poicy.groups == [] + assert poicy.mfa_groups == [] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +async def test_update_policy_to_empty_groups( + network_policy_use_case: NetworkPolicyUseCase, +) -> None: + """Test updating policy from groups to empty.""" + dto = NetworkPolicyDTO[None]( + id=None, + name="Test Update Groups", + netmasks=[IPv4Network("172.16.0.0/12")], + raw=["172.16.0.0/12"], + priority=3, + mfa_status=MFAFlags.DISABLED, + groups=["cn=domain admins,cn=Groups,dc=md,dc=test"], + mfa_groups=["cn=domain admins,cn=Groups,dc=md,dc=test"], + ) + + created = await network_policy_use_case.create(dto) + assert created.groups + assert created.mfa_groups + + update_dto = NetworkPolicyUpdateDTO( + id=created.id, + groups=[], + mfa_groups=[], + ) + + updated = await network_policy_use_case.update(update_dto) + + assert updated.groups == [] + assert updated.mfa_groups == [] diff --git a/tests/test_ldap/policies/test_password/datasets.py b/tests/test_ldap/policies/test_password/datasets.py index ea22dea5a..5aad7bf27 100644 --- 
a/tests/test_ldap/policies/test_password/datasets.py +++ b/tests/test_ldap/policies/test_password/datasets.py @@ -11,7 +11,7 @@ PasswordPolicyDTO[None, int]( id=None, priority=1, - group_paths=["cn=developers,cn=groups,dc=md,dc=test"], + group_paths=["cn=developers,cn=Groups,dc=md,dc=test"], name="Test Password Policy", language="Latin", is_exact_match=True, @@ -36,7 +36,7 @@ PasswordPolicyDTO[None, int]( id=None, priority=1, - group_paths=["cn=developers,cn=groups,dc=md,dc=test"], + group_paths=["cn=developers,cn=Groups,dc=md,dc=test"], name="Test Password Policy2", language="Latin", is_exact_match=True, @@ -61,7 +61,7 @@ PasswordPolicyDTO[None, int]( id=None, priority=1, - group_paths=["cn=developers,cn=groups,dc=md,dc=test"], + group_paths=["cn=developers,cn=Groups,dc=md,dc=test"], name="Test Password Policy3", language="Latin", is_exact_match=True, diff --git a/tests/test_ldap/policies/test_password/test_use_cases.py b/tests/test_ldap/policies/test_password/test_use_cases.py index a518df2e5..2b03dfd1d 100644 --- a/tests/test_ldap/policies/test_password/test_use_cases.py +++ b/tests/test_ldap/policies/test_password/test_use_cases.py @@ -48,7 +48,7 @@ async def test_get_password_policy_by_dir_path_dn( dto = PasswordPolicyDTO[None, int]( id=None, priority=1, - group_paths=["cn=developers,cn=groups,dc=md,dc=test"], + group_paths=["cn=developers,cn=Groups,dc=md,dc=test"], name="Test Password Policy", language="Latin", is_exact_match=True, @@ -75,7 +75,7 @@ async def test_get_password_policy_by_dir_path_dn( policies = await password_use_cases.get_all() assert any(policy.name == "Test Password Policy" for policy in policies) - path_dn = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + path_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" policy = await password_use_cases.get_password_policy_by_dir_path_dn( path_dn, ) @@ -100,7 +100,7 @@ async def test_get_password_policy_by_dir_path_dn_extended( policies = await password_use_cases.get_all() assert 
any(policy.name == "Test Password Policy" for policy in policies) - path_dn = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + path_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" policy = await password_use_cases.get_password_policy_by_dir_path_dn( path_dn, ) diff --git a/tests/test_ldap/test_bind.py b/tests/test_ldap/test_bind.py index e31353880..d43776644 100644 --- a/tests/test_ldap/test_bind.py +++ b/tests/test_ldap/test_bind.py @@ -287,7 +287,7 @@ async def test_bind_invalid_password_or_user( directory = Directory( name="user0", object_class="", - path=["cn=user0", "cn=users", "dc=md", "dc=test"], + path=["cn=user0", "cn=Users", "dc=md", "dc=test"], rdname="cn", ) session.add(directory) @@ -415,7 +415,7 @@ async def test_bind_disabled_user( directory = Directory( name="user0", object_class="", - path=["cn=user0", "cn=users", "dc=md", "dc=test"], + path=["cn=user0", "cn=Users", "dc=md", "dc=test"], rdname="cn", ) session.add(directory) diff --git a/tests/test_ldap/test_container_restrictions/test_container_subcontainers.py b/tests/test_ldap/test_container_restrictions/test_container_subcontainers.py index 4e1ee2de3..08eca190e 100644 --- a/tests/test_ldap/test_container_restrictions/test_container_subcontainers.py +++ b/tests/test_ldap/test_container_restrictions/test_container_subcontainers.py @@ -20,31 +20,31 @@ ("dn", "rdn_attr", "rdn_value", "object_classes"), [ ( - "cn=testcontainer,cn=users,dc=md,dc=test", + "cn=testcontainer,cn=Users,dc=md,dc=test", "cn", "testcontainer", ["container"], ), ( - "ou=testou,cn=users,dc=md,dc=test", + "ou=testou,cn=Users,dc=md,dc=test", "ou", "testou", ["organizationalUnit"], ), ( - "cn=testuser,cn=users,dc=md,dc=test", + "cn=testuser,cn=Users,dc=md,dc=test", "cn", "testuser", ["user", "organizationalPerson"], ), ( - "cn=testgroup,cn=groups,dc=md,dc=test", + "cn=testgroup,cn=Groups,dc=md,dc=test", "cn", "testgroup", ["group", "posixGroup"], ), ( - "cn=testcomputer,cn=computers,dc=md,dc=test", + 
"cn=testcomputer,cn=Computers,dc=md,dc=test", "cn", "testcomputer", ["computer", "organizationalPerson"], diff --git a/tests/test_ldap/test_ldap3_lib.py b/tests/test_ldap/test_ldap3_lib.py index aae675e15..756f2d142 100644 --- a/tests/test_ldap/test_ldap3_lib.py +++ b/tests/test_ldap/test_ldap3_lib.py @@ -27,11 +27,11 @@ async def test_ldap3_search(ldap_client: LDAPConnection) -> None: @pytest.mark.usefixtures("session") async def test_ldap3_search_memberof(ldap_client: LDAPConnection) -> None: """Test ldap3 search memberof.""" - member = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + member = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" result = await ldap_client.search( "dc=md,dc=test", - "(memberOf=cn=developers,cn=groups,dc=md,dc=test)", + "(memberOf=cn=developers,cn=Groups,dc=md,dc=test)", ) assert result diff --git a/tests/test_ldap/test_ldap_schema/__init__.py b/tests/test_ldap/test_ldap_schema/__init__.py new file mode 100644 index 000000000..5134a2e61 --- /dev/null +++ b/tests/test_ldap/test_ldap_schema/__init__.py @@ -0,0 +1,5 @@ +"""Test __init__ module. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" diff --git a/tests/test_ldap/test_ldap_schema/conftest.py b/tests/test_ldap/test_ldap_schema/conftest.py new file mode 100644 index 000000000..75b13a356 --- /dev/null +++ b/tests/test_ldap/test_ldap_schema/conftest.py @@ -0,0 +1,23 @@ +"""Conftest for LDAP schema AttributeType tests. 
+ +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import AsyncIterator + +import pytest_asyncio +from dishka import AsyncContainer, Scope + +from ldap_protocol.ldap_schema.attribute_type_use_case import ( + AttributeTypeUseCase, +) + + +@pytest_asyncio.fixture(scope="function") +async def attribute_type_use_case( + container: AsyncContainer, +) -> AsyncIterator[AttributeTypeUseCase]: + """Get di attribute_type_use_case.""" + async with container(scope=Scope.REQUEST) as container: + yield await container.get(AttributeTypeUseCase) diff --git a/tests/test_ldap/test_ldap_schema/test_attribute_type_use_case.py b/tests/test_ldap/test_ldap_schema/test_attribute_type_use_case.py new file mode 100644 index 000000000..0c359351a --- /dev/null +++ b/tests/test_ldap/test_ldap_schema/test_attribute_type_use_case.py @@ -0,0 +1,36 @@ +"""Test AttributeTypeUseCase. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import pytest + +from ldap_protocol.ldap_schema.attribute_type_use_case import ( + AttributeTypeUseCase, +) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_attribute_type_system_flags_use_case_is_not_replicated( + attribute_type_use_case: AttributeTypeUseCase, +) -> None: + """Test AttributeType is not replicated.""" + assert not await attribute_type_use_case.is_attr_replicated("netbootSCPBL") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_attribute_type_system_flags_use_case_is_replicated( + attribute_type_use_case: AttributeTypeUseCase, +) -> None: + """Test AttributeType is replicated.""" + assert await attribute_type_use_case.is_attr_replicated("objectClass") + await attribute_type_use_case.set_attr_replication_flag( + "objectClass", + False, + ) + assert not await 
attribute_type_use_case.is_attr_replicated("objectClass") diff --git a/tests/test_ldap/test_netlogon.py b/tests/test_ldap/test_netlogon.py index c12764df5..fa22188e0 100644 --- a/tests/test_ldap/test_netlogon.py +++ b/tests/test_ldap/test_netlogon.py @@ -315,6 +315,8 @@ def test_ds_flags_combination() -> None: | DSFlag.CLOSEST_FLAG | DSFlag.WRITABLE_FLAG | DSFlag.GOOD_TIMESERV_FLAG + | DSFlag.KDC_FLAG + | DSFlag.WS_FLAG ) assert ds_flags == expected_flags diff --git a/tests/test_ldap/test_passwd_change.py b/tests/test_ldap/test_passwd_change.py index cbf503b08..f01a93e7a 100644 --- a/tests/test_ldap/test_passwd_change.py +++ b/tests/test_ldap/test_passwd_change.py @@ -23,7 +23,7 @@ async def test_anonymous_pwd_change( password_utils: PasswordUtils, ) -> None: """Test anonymous pwd change.""" - user_dn = "cn=user0,cn=users,dc=md,dc=test" + user_dn = "cn=user0,cn=Users,dc=md,dc=test" password = creds.pw new_test_password = "Password123" # noqa await anonymous_ldap_client.modify_password( @@ -49,7 +49,7 @@ async def test_bind_pwd_change( password_utils: PasswordUtils, ) -> None: """Test anonymous pwd change.""" - user_dn = "cn=user0,cn=users,dc=md,dc=test" + user_dn = "cn=user0,cn=Users,dc=md,dc=test" password = creds.pw new_test_password = "Password123" # noqa await ldap_client.bind(user_dn, password) diff --git a/tests/test_ldap/test_roles/conftest.py b/tests/test_ldap/test_roles/conftest.py index e82c70526..2d5959e27 100644 --- a/tests/test_ldap/test_roles/conftest.py +++ b/tests/test_ldap/test_roles/conftest.py @@ -24,7 +24,7 @@ async def custom_role(role_dao: RoleDAO) -> RoleDTO: name="Custom Role", creator_upn=None, is_system=False, - groups=["cn=domain users,cn=groups,dc=md,dc=test"], + groups=["cn=domain users,cn=Groups,dc=md,dc=test"], ), ) return await role_dao.get(role_dao.get_last_id()) diff --git a/tests/test_ldap/test_roles/test_multiple_access.py b/tests/test_ldap/test_roles/test_multiple_access.py index da8cc17bc..4691ba0fb 100644 --- 
a/tests/test_ldap/test_roles/test_multiple_access.py +++ b/tests/test_ldap/test_roles/test_multiple_access.py @@ -18,7 +18,7 @@ from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO -from ldap_protocol.utils.queries import get_search_path +from ldap_protocol.utils.queries import get_filter_from_path from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds @@ -56,7 +56,7 @@ async def test_multiple_access( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=russia,cn=users,dc=md,dc=test", + base_dn="cn=russia,cn=Users,dc=md,dc=test", entity_type_id=user_entity_type.id, attribute_type_id=user_account_control_attr.id, is_allow=True, @@ -65,7 +65,7 @@ async def test_multiple_access( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=russia,cn=users,dc=md,dc=test", + base_dn="cn=russia,cn=Users,dc=md,dc=test", entity_type_id=user_entity_type.id, attribute_type_id=user_principal_name.id, is_allow=True, @@ -74,7 +74,7 @@ async def test_multiple_access( role_id=custom_role.get_id(), ace_type=AceType.WRITE, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=russia,cn=users,dc=md,dc=test", + base_dn="cn=russia,cn=Users,dc=md,dc=test", entity_type_id=user_entity_type.id, attribute_type_id=posix_email_attr.id, is_allow=True, @@ -83,7 +83,7 @@ async def test_multiple_access( role_id=custom_role.get_id(), ace_type=AceType.DELETE, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=russia,cn=users,dc=md,dc=test", + base_dn="cn=russia,cn=Users,dc=md,dc=test", entity_type_id=user_entity_type.id, attribute_type_id=posix_email_attr.id, is_allow=True, @@ -95,9 +95,9 @@ async def test_multiple_access( await perform_ldap_search_and_validate( settings=settings, creds=creds, - search_base="cn=russia,cn=users,dc=md,dc=test", + 
search_base="cn=russia,cn=Users,dc=md,dc=test", expected_dn=[ - "dn: cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", + "dn: cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", ], expected_attrs_present=[ "userAccountControl: 512", @@ -106,7 +106,7 @@ async def test_multiple_access( expected_attrs_absent=["posixEmail: user1@mail.com"], ) - user_dn = "cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" + user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" query = ( select(Directory) @@ -114,7 +114,7 @@ async def test_multiple_access( subqueryload(qa(Directory.attributes)), joinedload(qa(Directory.user)), ) - .filter_by(path=get_search_path(user_dn)) + .filter(get_filter_from_path(user_dn)) ) directory = (await session.scalars(query)).one() diff --git a/tests/test_ldap/test_roles/test_search.py b/tests/test_ldap/test_roles/test_search.py index a20e8f0dd..0795be89b 100644 --- a/tests/test_ldap/test_roles/test_search.py +++ b/tests/test_ldap/test_roles/test_search.py @@ -30,7 +30,7 @@ async def test_role_search_1(settings: Settings, creds: TestCreds) -> None: settings=settings, creds=creds, search_base=BASE_DN, - expected_dn=["dn: cn=user_non_admin,cn=users,dc=md,dc=test"], + expected_dn=["dn: cn=user_non_admin,cn=Users,dc=md,dc=test"], expected_attrs_present=[], expected_attrs_absent=[], ) @@ -52,7 +52,7 @@ async def test_role_search_2( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.BASE_OBJECT, - base_dn="cn=groups,dc=md,dc=test", + base_dn="cn=Groups,dc=md,dc=test", attribute_type_id=None, entity_type_id=None, is_allow=True, @@ -65,8 +65,8 @@ async def test_role_search_2( creds=creds, search_base=BASE_DN, expected_dn=[ - "dn: cn=groups,dc=md,dc=test", - "dn: cn=user_non_admin,cn=users,dc=md,dc=test", + "dn: cn=Groups,dc=md,dc=test", + "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", ], expected_attrs_present=[], expected_attrs_absent=[], @@ -102,9 +102,9 @@ async def test_role_search_3( creds=creds, search_base=BASE_DN, 
expected_dn=[ - "dn: cn=groups,dc=md,dc=test", - "dn: cn=users,dc=md,dc=test", - "dn: cn=user_non_admin,cn=users,dc=md,dc=test", + "dn: cn=Groups,dc=md,dc=test", + "dn: cn=Users,dc=md,dc=test", + "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", "dn: ou=test_bit_rules,dc=md,dc=test", "dn: ou=testModifyDn1,dc=md,dc=test", "dn: ou=testModifyDn3,dc=md,dc=test", @@ -130,7 +130,7 @@ async def test_role_search_4( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=groups,dc=md,dc=test", + base_dn="cn=Groups,dc=md,dc=test", attribute_type_id=None, entity_type_id=None, is_allow=True, @@ -143,13 +143,13 @@ async def test_role_search_4( creds=creds, search_base=BASE_DN, expected_dn=[ - "dn: cn=admin login only,cn=groups,dc=md,dc=test", - "dn: cn=groups,dc=md,dc=test", - "dn: cn=domain admins,cn=groups,dc=md,dc=test", - "dn: cn=domain computers,cn=groups,dc=md,dc=test", - "dn: cn=developers,cn=groups,dc=md,dc=test", - "dn: cn=domain users,cn=groups,dc=md,dc=test", - "dn: cn=user_non_admin,cn=users,dc=md,dc=test", + "dn: cn=admin login only,cn=Groups,dc=md,dc=test", + "dn: cn=Groups,dc=md,dc=test", + "dn: cn=domain admins,cn=Groups,dc=md,dc=test", + "dn: cn=domain computers,cn=Groups,dc=md,dc=test", + "dn: cn=developers,cn=Groups,dc=md,dc=test", + "dn: cn=domain users,cn=Groups,dc=md,dc=test", + "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", ], expected_attrs_present=[], expected_attrs_absent=[], @@ -189,11 +189,11 @@ async def test_role_search_5( creds=creds, search_base=BASE_DN, expected_dn=[ - "dn: cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test", - "dn: cn=user_non_admin,cn=users,dc=md,dc=test", - "dn: cn=user_admin_for_roles,cn=users,dc=md,dc=test", - "dn: cn=user_admin,cn=users,dc=md,dc=test", - "dn: cn=user0,cn=users,dc=md,dc=test", + "dn: cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test", + "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", + "dn: cn=user_admin_for_roles,cn=Users,dc=md,dc=test", + "dn: 
cn=user_admin,cn=Users,dc=md,dc=test", + "dn: cn=user0,cn=Users,dc=md,dc=test", "dn: cn=user_admin_1,ou=test_bit_rules,dc=md,dc=test", "dn: cn=user_admin_2,ou=test_bit_rules,dc=md,dc=test", "dn: cn=user_admin_3,ou=test_bit_rules,dc=md,dc=test", @@ -231,7 +231,7 @@ async def test_role_search_6( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.BASE_OBJECT, - base_dn="cn=user0,cn=users,dc=md,dc=test", + base_dn="cn=user0,cn=Users,dc=md,dc=test", attribute_type_id=posix_email_attr.id, entity_type_id=user_entity_type.id, is_allow=True, @@ -242,9 +242,9 @@ async def test_role_search_6( await perform_ldap_search_and_validate( settings=settings, creds=creds, - search_base="cn=user0,cn=users,dc=md,dc=test", + search_base="cn=user0,cn=Users,dc=md,dc=test", expected_dn=[ - "dn: cn=user0,cn=users,dc=md,dc=test", + "dn: cn=user0,cn=Users,dc=md,dc=test", ], expected_attrs_present=[ "posixEmail: abctest@mail.com", @@ -281,7 +281,7 @@ async def test_role_search_7( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.BASE_OBJECT, - base_dn="cn=user0,cn=users,dc=md,dc=test", + base_dn="cn=user0,cn=Users,dc=md,dc=test", attribute_type_id=None, entity_type_id=user_entity_type.id, is_allow=True, @@ -290,7 +290,7 @@ async def test_role_search_7( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.BASE_OBJECT, - base_dn="cn=user0,cn=users,dc=md,dc=test", + base_dn="cn=user0,cn=Users,dc=md,dc=test", attribute_type_id=description_attr.id, entity_type_id=user_entity_type.id, is_allow=False, @@ -302,9 +302,9 @@ async def test_role_search_7( await perform_ldap_search_and_validate( settings=settings, creds=creds, - search_base="cn=user0,cn=users,dc=md,dc=test", + search_base="cn=user0,cn=Users,dc=md,dc=test", expected_dn=[ - "dn: cn=user0,cn=users,dc=md,dc=test", + "dn: cn=user0,cn=Users,dc=md,dc=test", ], expected_attrs_present=[ "posixEmail: abctest@mail.com", @@ -350,7 +350,7 @@ async def test_role_search_8( role_id=custom_role.get_id(), 
ace_type=AceType.READ, scope=RoleScope.BASE_OBJECT, - base_dn="cn=user0,cn=users,dc=md,dc=test", + base_dn="cn=user0,cn=Users,dc=md,dc=test", attribute_type_id=description_attr.id, entity_type_id=user_entity_type.id, is_allow=True, @@ -362,9 +362,9 @@ async def test_role_search_8( await perform_ldap_search_and_validate( settings=settings, creds=creds, - search_base="cn=user0,cn=users,dc=md,dc=test", + search_base="cn=user0,cn=Users,dc=md,dc=test", expected_dn=[ - "dn: cn=user0,cn=users,dc=md,dc=test", + "dn: cn=user0,cn=Users,dc=md,dc=test", ], expected_attrs_present=[ "description: 123 desc", @@ -404,7 +404,7 @@ async def test_role_search_9( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=user0,cn=users,dc=md,dc=test", + base_dn="cn=user0,cn=Users,dc=md,dc=test", attribute_type_id=posix_email_attr.id, entity_type_id=user_entity_type.id, is_allow=True, @@ -413,7 +413,7 @@ async def test_role_search_9( role_id=custom_role.get_id(), ace_type=AceType.READ, scope=RoleScope.BASE_OBJECT, - base_dn="cn=user0,cn=users,dc=md,dc=test", + base_dn="cn=user0,cn=Users,dc=md,dc=test", attribute_type_id=description_attr.id, entity_type_id=user_entity_type.id, is_allow=False, @@ -425,9 +425,9 @@ async def test_role_search_9( await perform_ldap_search_and_validate( settings=settings, creds=creds, - search_base="cn=user0,cn=users,dc=md,dc=test", + search_base="cn=user0,cn=Users,dc=md,dc=test", expected_dn=[ - "dn: cn=user0,cn=users,dc=md,dc=test", + "dn: cn=user0,cn=Users,dc=md,dc=test", ], expected_attrs_present=[ "posixEmail: abctest@mail.com", diff --git a/tests/test_ldap/test_util/test_add.py b/tests/test_ldap/test_util/test_add.py index eef7d047b..b0312bc98 100644 --- a/tests/test_ldap/test_util/test_add.py +++ b/tests/test_ldap/test_util/test_add.py @@ -23,7 +23,7 @@ from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO from 
ldap_protocol.roles.role_dao import RoleDAO -from ldap_protocol.utils.queries import get_search_path +from ldap_protocol.utils.queries import get_filter_from_path from repo.pg.tables import queryable_attr as qa from tests.conftest import TestCreds @@ -37,7 +37,6 @@ async def test_ldap_root_add( ) -> None: """Test ldapadd on server.""" dn = "cn=test,dc=md,dc=test" - search_path = get_search_path(dn) with tempfile.NamedTemporaryFile("w") as file: file.write( ( @@ -46,7 +45,7 @@ async def test_ldap_root_add( "cn: test\n" "objectClass: organization\n" "objectClass: top\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" + "memberOf: cn=domain admins,cn=Groups,dc=md,dc=test\n" ), ) file.seek(0) @@ -73,7 +72,7 @@ async def test_ldap_root_add( new_dir_query = ( select(Directory) .options(subqueryload(qa(Directory.attributes))) - .filter_by(path=search_path) + .filter(get_filter_from_path(dn)) ) new_dir = (await session.scalars(new_dir_query)).one() @@ -96,8 +95,8 @@ async def test_ldap_user_add_with_group( ) -> None: """Test ldapadd on server.""" user_dn = "cn=test,dc=md,dc=test" - user_search_path = get_search_path(user_dn) - group_dn = "cn=domain admins,cn=groups,dc=md,dc=test" + + group_dn = "cn=domain admins,cn=Groups,dc=md,dc=test" with tempfile.NamedTemporaryFile("w") as file: file.write( @@ -144,7 +143,7 @@ async def test_ldap_user_add_with_group( new_dir_query = ( select(Directory) .options(subqueryload(qa(Directory.attributes)), membership) - .filter_by(path=user_search_path) + .filter(get_filter_from_path(user_dn)) ) new_dir = (await session.scalars(new_dir_query)).one() @@ -163,8 +162,7 @@ async def test_ldap_user_add_group_with_group( user: dict, ) -> None: """Test ldapadd on server.""" - child_group_dn = "cn=twisted,cn=groups,dc=md,dc=test" - child_group_search_path = get_search_path(child_group_dn) + child_group_dn = "cn=twisted,cn=Groups,dc=md,dc=test" group_dn = "cn=domain admins,cn=groups,dc=md,dc=test" with tempfile.NamedTemporaryFile("w") as 
file: @@ -208,13 +206,16 @@ async def test_ldap_user_add_group_with_group( new_dir_query = ( select(Directory) .options(membership) - .filter_by(path=child_group_search_path) + .filter(get_filter_from_path(child_group_dn)) ) new_dir = (await session.scalars(new_dir_query)).one() assert new_dir.name == "twisted" - groups = [group.directory.path_dn for group in new_dir.group.parent_groups] + groups = [ + group.directory.path_dn.lower() + for group in new_dir.group.parent_groups + ] assert group_dn in groups @@ -287,7 +288,7 @@ async def try_add() -> int: name="Add Role", creator_upn=None, is_system=False, - groups=["cn=domain users,cn=groups," + base_dn], + groups=["cn=domain users,cn=Groups," + base_dn], ), ) @@ -355,7 +356,7 @@ async def test_ldap_user_add_with_duplicate_groups( ) -> None: """Duplicate memberOf yields single membership.""" user_dn = "cn=dup,dc=md,dc=test" - group_dn = "cn=domain admins,cn=groups,dc=md,dc=test" + group_dn = "cn=domain admins,cn=Groups,dc=md,dc=test" with tempfile.NamedTemporaryFile("w") as file: ldif = [ @@ -394,11 +395,10 @@ async def test_ldap_user_add_with_duplicate_groups( assert result == 0 - user_search_path = get_search_path(user_dn) user_row = await session.scalar( select(User) .join(qa(User.directory)) - .filter_by(path=user_search_path) + .filter(get_filter_from_path(user_dn)) .options( selectinload(qa(User.groups)).selectinload(qa(Group.directory)), ), diff --git a/tests/test_ldap/test_util/test_delete.py b/tests/test_ldap/test_util/test_delete.py index f93d1eed9..bff5011c2 100644 --- a/tests/test_ldap/test_util/test_delete.py +++ b/tests/test_ldap/test_util/test_delete.py @@ -39,7 +39,7 @@ async def test_ldap_delete( "cn: test\n" "objectClass: organization\n" "objectClass: top\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" + "memberOf: cn=domain admins,cn=Groups,dc=md,dc=test\n" ), ) file.seek(0) @@ -94,7 +94,7 @@ async def test_ldap_delete( "-x", "-w", user["password"], - "cn=user0,cn=users,dc=md,dc=test", 
+ "cn=user0,cn=Users,dc=md,dc=test", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -171,7 +171,7 @@ async def try_delete() -> int: name="Delete Role", creator_upn=None, is_system=False, - groups=["cn=domain users,cn=groups," + base_dn], + groups=["cn=domain users,cn=Groups," + base_dn], ), ) @@ -223,7 +223,7 @@ async def test_ldap_delete_primary_object_classes( user: dict, ) -> None: """Test deleting primary object class.""" - entry_dn = "cn=user0,cn=users,dc=md,dc=test" + entry_dn = "cn=user0,cn=Users,dc=md,dc=test" with tempfile.NamedTemporaryFile("w") as file: file.write( ( diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index 02b174b4b..b5eadf172 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -18,14 +18,15 @@ from config import Settings from entities import Directory, Group -from enums import AceType, RoleScope +from enums import AceType, EntityTypeNames, RoleScope from ldap_protocol.kerberos.base import AbstractKadmin from ldap_protocol.ldap_codes import LDAPCodes +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.objects import Operation from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO from ldap_protocol.roles.role_dao import RoleDAO -from ldap_protocol.utils.queries import get_search_path +from ldap_protocol.utils.queries import get_filter_from_path from repo.pg.tables import Attribute, directory_table, queryable_attr as qa from tests.conftest import TestCreds @@ -38,14 +39,14 @@ async def test_ldap_base_modify( user: dict, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user0,cn=users,dc=md,dc=test" + dn = "cn=user0,cn=Users,dc=md,dc=test" query = ( select(Directory) .options( subqueryload(qa(Directory.attributes)), joinedload(qa(Directory.user)), ) - .filter_by(path=get_search_path(dn)) + 
.filter(get_filter_from_path(dn)) ) directory = (await session.scalars(query)).one() @@ -139,11 +140,11 @@ async def test_ldap_membersip_user_delete( user: dict, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user_admin,cn=users,dc=md,dc=test" + dn = "cn=user_admin,cn=Users,dc=md,dc=test" query = ( select(Directory) .options(selectinload(qa(Directory.groups))) - .filter_by(path=get_search_path(dn)) + .filter(get_filter_from_path(dn)) ) directory = (await session.scalars(query)).one() @@ -187,11 +188,11 @@ async def test_ldap_membersip_self_delete_admin_domain( user: dict, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user0,cn=users,dc=md,dc=test" + dn = "cn=user0,cn=Users,dc=md,dc=test" query = ( select(Directory) .options(selectinload(qa(Directory.groups))) - .filter_by(path=get_search_path(dn)) + .filter(get_filter_from_path(dn)) ) directory = (await session.scalars(query)).one() @@ -201,7 +202,7 @@ async def test_ldap_membersip_self_delete_admin_domain( with tempfile.NamedTemporaryFile("w") as file: file.write( f"dn: {dn}\nchangetype: modify\ndelete: memberOf\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n", + "memberOf: cn=domain admins,cn=Groups,dc=md,dc=test\n", ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -250,7 +251,7 @@ async def test_self_disable( response = await http_client.patch( "entry/update", json={ - "object": "cn=user0,cn=users,dc=md,dc=test", + "object": "cn=user0,cn=Users,dc=md,dc=test", "changes": [ { "operation": Operation.REPLACE, @@ -288,7 +289,7 @@ async def test_ldap_membersip_user_add( creds: TestCreds, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user_non_admin,cn=users,dc=md,dc=test" + dn = "cn=user_non_admin,cn=Users,dc=md,dc=test" query = ( select(Directory) .options( @@ -296,7 +297,7 @@ async def test_ldap_membersip_user_add( qa(Group.directory), ), ) - .filter_by(path=get_search_path(dn)) + .filter(get_filter_from_path(dn)) ) directory = (await session.scalars(query)).one() @@ 
-312,7 +313,7 @@ async def test_ldap_membersip_user_add( f"dn: {dn}\n" "changetype: modify\n" "add: memberOf\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" + "memberOf: cn=domain admins,cn=Groups,dc=md,dc=test\n" "-\n" ), ) @@ -351,17 +352,17 @@ async def test_ldap_membersip_user_replace( user: dict, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user_admin,cn=users,dc=md,dc=test" + dn = "cn=user_admin,cn=Users,dc=md,dc=test" query = ( select(Directory) .options(selectinload(qa(Directory.groups))) - .filter_by(path=get_search_path(dn)) + .filter(get_filter_from_path(dn)) ) directory = (await session.scalars(query)).one() assert directory.groups - new_group_dn = "cn=twisted,cn=groups,dc=md,dc=test\n" + new_group_dn = "cn=twisted,cn=Groups,dc=md,dc=test\n" # add new group with tempfile.NamedTemporaryFile("w") as file: @@ -372,7 +373,7 @@ async def test_ldap_membersip_user_replace( "cn: twisted\n" "objectClass: group\n" "objectClass: top\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" + "memberOf: cn=domain admins,cn=Groups,dc=md,dc=test\n" ), ) file.seek(0) @@ -403,7 +404,7 @@ async def test_ldap_membersip_user_replace( f"dn: {dn}\n" "changetype: modify\n" "replace: memberOf\n" - "memberOf: cn=twisted,cn=groups,dc=md,dc=test\n" + "memberOf: cn=twisted,cn=Groups,dc=md,dc=test\n" "-\n" ), ) @@ -442,7 +443,7 @@ async def test_ldap_membersip_grp_replace( user: dict, ) -> None: """Test ldapmodify on server.""" - dn = "cn=domain admins,cn=groups,dc=md,dc=test" + dn = "cn=domain admins,cn=Groups,dc=md,dc=test" query = ( select(Directory) @@ -451,7 +452,7 @@ async def test_ldap_membersip_grp_replace( .selectinload(qa(Group.parent_groups)) .selectinload(qa(Group.directory)), ) - .filter_by(path=get_search_path(dn)) + .filter(get_filter_from_path(dn)) ) directory = await session.scalar(query) @@ -463,7 +464,7 @@ async def test_ldap_membersip_grp_replace( with tempfile.NamedTemporaryFile("w") as file: file.write( ( - "dn: 
cn=twisted1,cn=groups,dc=md,dc=test\n" + "dn: cn=twisted1,cn=Groups,dc=md,dc=test\n" "name: twisted\n" "cn: twisted\n" "objectClass: group\n" @@ -498,7 +499,7 @@ async def test_ldap_membersip_grp_replace( f"dn: {dn}\n" "changetype: modify\n" "replace: memberOf\n" - "memberOf: cn=twisted1,cn=groups,dc=md,dc=test\n" + "memberOf: cn=twisted1,cn=Groups,dc=md,dc=test\n" "-\n" ), ) @@ -537,7 +538,7 @@ async def test_ldap_modify_dn( user: dict, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user0,cn=users,dc=md,dc=test" + dn = "cn=user0,cn=Users,dc=md,dc=test" with tempfile.NamedTemporaryFile("w") as file: file.write( @@ -546,7 +547,7 @@ async def test_ldap_modify_dn( "changetype: modrdn\n" "newrdn: cn=user2\n" "deleteoldrdn: 1\n" - "newsuperior: cn=users,dc=md,dc=test\n" + "newsuperior: cn=Users,dc=md,dc=test\n" ), ) file.seek(0) @@ -574,7 +575,7 @@ async def test_ldap_modify_dn( select(Directory) .filter( directory_table.c.path - == ["dc=test", "dc=md", "cn=users", "cn=user2"], + == ["dc=test", "dc=md", "cn=Users", "cn=user2"], directory_table.c.entity_type_id.isnot(None), ), ) # fmt: skip @@ -588,7 +589,7 @@ async def test_ldap_modify_password_change( creds: TestCreds, ) -> None: """Test ldapmodify on server.""" - dn = "cn=user0,cn=users,dc=md,dc=test" + dn = "cn=user0,cn=Users,dc=md,dc=test" new_password = "Password12345" # noqa with tempfile.NamedTemporaryFile("w") as file: @@ -655,9 +656,8 @@ async def test_ldap_modify_with_ap( access_control_entry_dao: AccessControlEntryDAO, ) -> None: """Test ldapmodify on server.""" - dn = "cn=users,dc=md,dc=test" + dn = "cn=Users,dc=md,dc=test" base_dn = "dc=md,dc=test" - search_path = get_search_path(dn) query = ( select(Directory) @@ -665,7 +665,7 @@ async def test_ldap_modify_with_ap( subqueryload(qa(Directory.attributes)), joinedload(qa(Directory.user)), ) - .filter_by(path=search_path) + .filter(get_filter_from_path(dn)) ) directory = await session.scalar(query) @@ -719,7 +719,7 @@ async def try_modify() -> int: 
name="Modify Role", creator_upn=None, is_system=False, - groups=["cn=domain users,cn=groups," + base_dn], + groups=["cn=domain users,cn=Groups," + base_dn], ), ) @@ -781,6 +781,104 @@ async def try_modify() -> int: assert "posixEmail" not in attributes +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +async def test_ldap_modify_rdn( + settings: Settings, + creds: TestCreds, +) -> None: + """Test modify RDN.""" + dn = "cn=user0,cn=Users,dc=md,dc=test" + + async def try_modify() -> int: + with tempfile.NamedTemporaryFile("w") as file: + file.write( + (f"dn: {dn}\nchangetype: modify\nreplace: cn\ncn: modme\n-\n"), + ) + file.seek(0) + proc = await asyncio.create_subprocess_exec( + "ldapmodify", + "-vvv", + "-H", + f"ldap://{settings.HOST}:{settings.PORT}", + "-D", + "user_admin", + "-x", + "-w", + creds.pw, + "-f", + file.name, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + await proc.communicate() + return await proc.wait() + + assert await try_modify() == LDAPCodes.NOT_ALLOWED_ON_RDN + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +async def test_ldap_modify_name( + session: AsyncSession, + settings: Settings, + creds: TestCreds, +) -> None: + """Test modify name.""" + dn = "cn=user0,cn=Users,dc=md,dc=test" + + query = ( + select(Directory) + .options( + subqueryload(qa(Directory.attributes)), + joinedload(qa(Directory.user)), + ) + .filter(get_filter_from_path(dn)) + ) + + old_directory = await session.scalar(query) + assert old_directory + + async def try_modify() -> int: + with tempfile.NamedTemporaryFile("w") as file: + file.write( + ( + f"dn: {dn}\n" + "changetype: modify\n" + "replace: name\n" + "name: changename\n" + "-\n" + ), + ) + file.seek(0) + proc = await asyncio.create_subprocess_exec( + "ldapmodify", + "-vvv", + "-H", + f"ldap://{settings.HOST}:{settings.PORT}", + "-D", + "user_admin", + "-x", + "-w", + creds.pw, + "-f", + file.name, + stdout=asyncio.subprocess.PIPE, + 
stderr=asyncio.subprocess.PIPE, + ) + + await proc.communicate() + return await proc.wait() + + assert await try_modify() == LDAPCodes.SUCCESS + + new_directory = await session.scalar(query) + assert new_directory + + assert old_directory.name == new_directory.name + + async def run_single_modify( settings: Settings, operation: Literal["add", "delete", "replace"], @@ -822,6 +920,49 @@ async def run_single_modify( return await proc.wait() +async def run_single_modrdn( + *, + settings: Settings, + bind_dn: str, + password: str, + dn: str, + newrdn: str, + deleteoldrdn: int = 1, + newsuperior: str | None = None, +) -> int: + with tempfile.NamedTemporaryFile("w") as file: + lines = [ + f"dn: {dn}", + "changetype: modrdn", + f"newrdn: {newrdn}", + f"deleteoldrdn: {deleteoldrdn}", + ] + if newsuperior is not None: + lines.append(f"newsuperior: {newsuperior}") + + file.write("\n".join(lines) + "\n") + file.seek(0) + + proc = await asyncio.create_subprocess_exec( + "ldapmodify", + "-vvv", + "-H", + f"ldap://{settings.HOST}:{settings.PORT}", + "-D", + bind_dn, + "-x", + "-w", + password, + "-f", + file.name, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + await proc.communicate() + return await proc.wait() + + async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: """Fetch directory by DN.""" query = ( @@ -831,7 +972,7 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)), ) - .filter(qa(Directory.path) == get_search_path(dn)) + .filter(get_filter_from_path(dn)) ) return (await session.scalars(query)).one() @@ -843,25 +984,25 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: [ ( "add", - "cn=developers,cn=groups,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", {"domain admins", "developers"}, True, ), ( "add", - "cn=domain admins,cn=groups,dc=md,dc=test", + "cn=domain 
admins,cn=Groups,dc=md,dc=test", {"domain admins"}, True, ), ( "delete", - "cn=developers,cn=groups,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", {"domain admins", "developers"}, False, ), ( "replace", - "cn=developers,cn=groups,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", {"domain admins", "developers"}, True, ), @@ -877,7 +1018,7 @@ async def test_ldap_modify_primary_group_id_scenarios( creds: TestCreds, ) -> None: """Test ldapmodify request with primaryGroupID for various scenarios.""" - user_dn = "cn=user_admin,cn=users,dc=md,dc=test" + user_dn = "cn=user_admin,cn=Users,dc=md,dc=test" user_dir = await fetch_directory_by_dn(session, user_dn) group_dir = await fetch_directory_by_dn(session, group_dn) @@ -932,22 +1073,22 @@ async def test_ldap_modify_primary_group_id_scenarios( ("values", "include_dev_group", "expected_result", "expected_groups"), [ ( - ["cn=domain admins,cn=groups,dc=md,dc=test"], + ["cn=domain admins,cn=Groups,dc=md,dc=test"], True, 1, {"domain admins", "developers"}, ), ( - ["cn=domain admins,cn=groups,dc=md,dc=test"], + ["cn=domain admins,cn=Groups,dc=md,dc=test"], False, 0, {"domain admins"}, ), ( [ - "cn=domain admins,cn=groups,dc=md,dc=test", - "cn=developers,cn=groups,dc=md,dc=test", - "cn=domain computers,cn=groups,dc=md,dc=test", + "cn=domain admins,cn=Groups,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", + "cn=domain computers,cn=Groups,dc=md,dc=test", ], True, 0, @@ -965,8 +1106,8 @@ async def test_ldap_modify_replace_memberof_primary_group_various( creds: TestCreds, ) -> None: """Test ldapmodify request replace memberOf attribute.""" - user_dn = "cn=user_admin,cn=users,dc=md,dc=test" - dev_group_dn = "cn=developers,cn=groups,dc=md,dc=test" + user_dn = "cn=user_admin,cn=Users,dc=md,dc=test" + dev_group_dn = "cn=developers,cn=Groups,dc=md,dc=test" user_dir = await fetch_directory_by_dn(session, user_dn) dev_group_dir = await fetch_directory_by_dn(session, dev_group_dn) @@ -998,3 +1139,244 @@ async 
def test_ldap_modify_replace_memberof_primary_group_various( user_dir = await fetch_directory_by_dn(session, user_dn) group_names = {group.directory.name for group in user_dir.groups} assert group_names == expected_groups + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +async def test_modify_dn_rename_with_ap( + settings: Settings, + creds: TestCreds, + role_dao: RoleDAO, + access_control_entry_dao: AccessControlEntryDAO, + entity_type_dao: EntityTypeDAO, + attribute_type_dao: EntityTypeDAO, +) -> None: + dn = "cn=user0,cn=Users,dc=md,dc=test" + base_dn = "dc=md,dc=test" + + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) + assert user_entity_type + + rdn_attr = await attribute_type_dao.get("cn") + assert rdn_attr + + res = await run_single_modrdn( + settings=settings, + bind_dn="user_non_admin", + password=creds.pw, + dn=dn, + newrdn="cn=user2", + deleteoldrdn=1, + ) + + assert res == LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS + + await role_dao.create( + dto=RoleDTO( + name="Modify Role", + creator_upn=None, + is_system=False, + groups=["cn=domain users,cn=Groups," + base_dn], + ), + ) + + role_id = role_dao.get_last_id() + + write_ace = AccessControlEntryDTO( + role_id=role_id, + ace_type=AceType.WRITE, + scope=RoleScope.WHOLE_SUBTREE, + base_dn=dn, + attribute_type_id=rdn_attr.id, + entity_type_id=user_entity_type.id, + is_allow=True, + ) + delete_ace = AccessControlEntryDTO( + role_id=role_id, + ace_type=AceType.DELETE, + scope=RoleScope.WHOLE_SUBTREE, + base_dn=dn, + attribute_type_id=rdn_attr.id, + entity_type_id=user_entity_type.id, + is_allow=True, + ) + + await access_control_entry_dao.create_bulk([write_ace, delete_ace]) + + aces_before = await access_control_entry_dao.get_all() + + res = await run_single_modrdn( + settings=settings, + bind_dn="user_non_admin", + password=creds.pw, + dn=dn, + newrdn="cn=user2", + deleteoldrdn=1, + ) + + assert res == LDAPCodes.SUCCESS + + aces_after = await 
access_control_entry_dao.get_all() + + inherited_aces_before = [ + ace for ace in aces_before if ace.base_dn == base_dn + ] + explicit_aces_before = [ + ace for ace in aces_before if ace.base_dn != base_dn + ] + + inherited_aces_after = [ + ace for ace in aces_after if ace.base_dn == base_dn + ] + explicit_aces_after = [ace for ace in aces_after if ace.base_dn != base_dn] + + assert inherited_aces_before == inherited_aces_after + assert len(explicit_aces_after) == len(explicit_aces_before) + + # NOTE: Check explicit ACEs have same properties except base_dn + for ace_before, ace_after in zip( + explicit_aces_before, + explicit_aces_after, + ): + assert ace_before.id == ace_after.id + assert ace_before.role_id == ace_after.role_id + assert ace_before.ace_type == ace_after.ace_type + assert ace_before.scope == ace_after.scope + assert ace_before.attribute_type_id == ace_after.attribute_type_id + assert ace_before.entity_type_id == ace_after.entity_type_id + assert ace_before.is_allow == ace_after.is_allow + + assert ace_after.base_dn == "cn=user2,cn=Users,dc=md,dc=test" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +async def test_modify_dn_move_with_ap( + settings: Settings, + creds: TestCreds, + role_dao: RoleDAO, + access_control_entry_dao: AccessControlEntryDAO, + entity_type_dao: EntityTypeDAO, + attribute_type_dao: EntityTypeDAO, +) -> None: + dn = "cn=user0,cn=Users,dc=md,dc=test" + base_dn = "dc=md,dc=test" + + user_entity_type = await entity_type_dao.get(EntityTypeNames.USER) + assert user_entity_type + + rdn_attr = await attribute_type_dao.get("cn") + assert rdn_attr + + new_parent_dn = "cn=Groups,dc=md,dc=test" + + res = await run_single_modrdn( + settings=settings, + bind_dn="user_non_admin", + password=creds.pw, + dn=dn, + newrdn="cn=user2", + deleteoldrdn=1, + newsuperior=new_parent_dn, + ) + + assert res == LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS + + await role_dao.create( + dto=RoleDTO( + name="Modify Role", + creator_upn=None, + 
is_system=False, + groups=["cn=domain users,cn=Groups," + base_dn], + ), + ) + + role_id = role_dao.get_last_id() + + write_ace = AccessControlEntryDTO( + role_id=role_id, + ace_type=AceType.WRITE, + scope=RoleScope.WHOLE_SUBTREE, + base_dn=dn, + attribute_type_id=rdn_attr.id, + entity_type_id=user_entity_type.id, + is_allow=True, + ) + create_ace = AccessControlEntryDTO( + role_id=role_id, + ace_type=AceType.CREATE_CHILD, + scope=RoleScope.WHOLE_SUBTREE, + base_dn=new_parent_dn, + attribute_type_id=None, + entity_type_id=user_entity_type.id, + is_allow=True, + ) + delete_ace = AccessControlEntryDTO( + role_id=role_id, + ace_type=AceType.DELETE, + scope=RoleScope.WHOLE_SUBTREE, + base_dn=dn, + attribute_type_id=None, + entity_type_id=user_entity_type.id, + is_allow=True, + ) + + await access_control_entry_dao.create_bulk( + [write_ace, create_ace, delete_ace], + ) + + aces_before = await access_control_entry_dao.get_all() + + res = await run_single_modrdn( + settings=settings, + bind_dn="user_non_admin", + password=creds.pw, + dn=dn, + newrdn="cn=user2", + deleteoldrdn=1, + newsuperior=new_parent_dn, + ) + + assert res == LDAPCodes.SUCCESS + + aces_after = await access_control_entry_dao.get_all() + + inherited_aces_before = [ + ace + for ace in aces_before + if ace.base_dn != "cn=user0,cn=Users,dc=md,dc=test" + ] + explicit_aces_before = [ + ace + for ace in aces_before + if ace.base_dn == "cn=user0,cn=Users,dc=md,dc=test" + ] + + inherited_aces_after = [ + ace + for ace in aces_after + if ace.base_dn != "cn=user2,cn=Groups,dc=md,dc=test" + ] + explicit_aces_after = [ + ace + for ace in aces_after + if ace.base_dn == "cn=user2,cn=Groups,dc=md,dc=test" + ] + + assert inherited_aces_before == inherited_aces_after + assert len(explicit_aces_after) == len(explicit_aces_before) + + # check expicit aces have same properties except base_dn + for ace_before, ace_after in zip( + explicit_aces_before, + explicit_aces_after, + ): + assert ace_before.id == ace_after.id + 
assert ace_before.role_id == ace_after.role_id + assert ace_before.ace_type == ace_after.ace_type + assert ace_before.scope == ace_after.scope + assert ace_before.attribute_type_id == ace_after.attribute_type_id + assert ace_before.entity_type_id == ace_after.entity_type_id + assert ace_before.is_allow == ace_after.is_allow + + assert ace_after.base_dn == "cn=user2,cn=Groups,dc=md,dc=test" diff --git a/tests/test_ldap/test_util/test_search.py b/tests/test_ldap/test_util/test_search.py index 903fb2598..338822a62 100644 --- a/tests/test_ldap/test_util/test_search.py +++ b/tests/test_ldap/test_util/test_search.py @@ -62,9 +62,9 @@ async def test_ldap_search(settings: Settings, creds: TestCreds) -> None: result = await proc.wait() assert result == 0 - assert "dn: cn=groups,dc=md,dc=test" in data - assert "dn: cn=users,dc=md,dc=test" in data - assert "dn: cn=user0,cn=users,dc=md,dc=test" in data + assert "dn: cn=Groups,dc=md,dc=test" in data + assert "dn: cn=Users,dc=md,dc=test" in data + assert "dn: cn=user0,cn=Users,dc=md,dc=test" in data @pytest.mark.asyncio @@ -89,7 +89,7 @@ async def test_ldap_search_filter( "dc=md,dc=test", "(&" "(objectClass=user)" - "(memberOf:1.2.840.113556.1.4.1941:=cn=domain admins,cn=groups,dc=md,\ + "(memberOf:1.2.840.113556.1.4.1941:=cn=domain admins,cn=Groups,dc=md,\ dc=test)" ")", stdout=asyncio.subprocess.PIPE, @@ -101,8 +101,8 @@ async def test_ldap_search_filter( result = await proc.wait() assert result == 0 - assert "dn: cn=user0,cn=users,dc=md,dc=test" in data - assert "dn: cn=user1,cn=moscow,cn=russia,cn=users,dc=md,dc=test" in data + assert "dn: cn=user0,cn=Users,dc=md,dc=test" in data + assert "dn: cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" in data @pytest.mark.asyncio @@ -298,7 +298,7 @@ async def test_ldap_search_filter_prefix( result = await proc.wait() assert result == 0 - assert "dn: cn=user0,cn=users,dc=md,dc=test" in data + assert "dn: cn=user0,cn=Users,dc=md,dc=test" in data @pytest.mark.asyncio @@ -317,7 +317,7 
@@ async def test_bind_policy( assert policy group = await get_group( - dn="cn=domain admins,cn=groups,dc=md,dc=test", + dn="cn=domain admins,cn=Groups,dc=md,dc=test", session=session, ) policy.groups.append(group) @@ -368,7 +368,7 @@ async def test_bind_policy_missing_group( user = (await session.scalars(user_query)).one() policy.groups = await get_groups( - ["cn=domain admins,cn=groups,dc=md,dc=test"], + ["cn=domain admins,cn=Groups,dc=md,dc=test"], session, ) user.groups.clear() @@ -432,7 +432,7 @@ async def test_bvalue_in_search_request( ) -> None: """Test SearchRequest with bytes data.""" request = SearchRequest( - base_object="cn=user0,cn=users,dc=md,dc=test", + base_object="cn=user0,cn=Users,dc=md,dc=test", scope=0, deref_aliases=0, size_limit=0, @@ -525,7 +525,7 @@ async def test_ldap_search_access_control_denied( assert result == 0 assert dn_list == [ - "dn: cn=user_non_admin,cn=users,dc=md,dc=test", + "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", ] await session.commit() @@ -535,7 +535,7 @@ async def test_ldap_search_access_control_denied( name="Groups Read Role", creator_upn=None, is_system=False, - groups=["cn=domain users,cn=groups,dc=md,dc=test"], + groups=["cn=domain users,cn=Groups,dc=md,dc=test"], ), ) @@ -543,7 +543,7 @@ async def test_ldap_search_access_control_denied( role_id=role_dao.get_last_id(), ace_type=AceType.READ, scope=RoleScope.WHOLE_SUBTREE, - base_dn="cn=groups,dc=md,dc=test", + base_dn="cn=Groups,dc=md,dc=test", attribute_type_id=None, entity_type_id=None, is_allow=True, @@ -577,12 +577,12 @@ async def test_ldap_search_access_control_denied( assert result == 0 assert sorted(dn_list) == sorted( [ - "dn: cn=groups,dc=md,dc=test", - "dn: cn=domain admins,cn=groups,dc=md,dc=test", - "dn: cn=admin login only,cn=groups,dc=md,dc=test", - "dn: cn=developers,cn=groups,dc=md,dc=test", - "dn: cn=domain computers,cn=groups,dc=md,dc=test", - "dn: cn=domain users,cn=groups,dc=md,dc=test", - "dn: cn=user_non_admin,cn=users,dc=md,dc=test", + "dn: 
cn=Groups,dc=md,dc=test", + "dn: cn=domain admins,cn=Groups,dc=md,dc=test", + "dn: cn=admin login only,cn=Groups,dc=md,dc=test", + "dn: cn=developers,cn=Groups,dc=md,dc=test", + "dn: cn=domain computers,cn=Groups,dc=md,dc=test", + "dn: cn=domain users,cn=Groups,dc=md,dc=test", + "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", ], ) diff --git a/tests/test_shedule.py b/tests/test_shedule.py index dc5aaaf01..fa293902a 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -8,11 +8,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from config import Settings +from extra.scripts.add_domain_controller import add_domain_controller from extra.scripts.check_ldap_principal import check_ldap_principal from extra.scripts.principal_block_user_sync import principal_block_sync from extra.scripts.uac_sync import disable_accounts from extra.scripts.update_krb5_config import update_krb5_config from ldap_protocol.kerberos import AbstractKadmin +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.roles.role_use_case import RoleUseCase @pytest.mark.asyncio @@ -73,3 +76,21 @@ async def test_update_krb5_config( session=session, settings=settings, ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_add_domain_controller( + session: AsyncSession, + settings: Settings, + role_use_case: RoleUseCase, + entity_type_dao: EntityTypeDAO, +) -> None: + """Test add domain controller.""" + await add_domain_controller( + settings=settings, + session=session, + role_use_case=role_use_case, + entity_type_dao=entity_type_dao, + ) diff --git a/traefik.yml b/traefik.yml index f95bf72f3..cdb9a6ee3 100644 --- a/traefik.yml +++ b/traefik.yml @@ -7,6 +7,12 @@ api: ping: entryPoint: "ping" +tcp: + serversTransports: + ldap_transport: + proxyProtocol: + version: 2 + entryPoints: ping: address: ":8800" @@ -34,8 +40,6 @@ entryPoints: address: ":749" kpasswd: address: ":464" - bind_dns_udp: - 
address: ":53/udp" tls: stores: diff --git a/uv.lock b/uv.lock index 85838a5a8..8154ccc0e 100644 --- a/uv.lock +++ b/uv.lock @@ -251,6 +251,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b9/89381173b4f336e986d72471198614806cd313e0f85c143ccb677c310223/dishka-1.7.2-py3-none-any.whl", hash = "sha256:f6faa6ab321903926b825b3337d77172ee693450279b314434864978d01fbad3", size = 94774, upload-time = "2025-09-24T21:23:03.246Z" }, ] +[[package]] +name = "dnsdist-console" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "libnacl" }, + { name = "scrypt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e3/f9/1d3eb92c2a94af1fd970b42e48584f544e832a3d40c86d14340f1def78db/dnsdist_console-1.6.0.tar.gz", hash = "sha256:4afb35b52640db5c4865aa6458147651c757907465204497e6c74f36f5a7eb0a", size = 10337, upload-time = "2025-04-02T10:29:53.382Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/54/5005c1355e3d97ee4cf7e5d118836d898d5ed58470ab8a63139dead16b14/dnsdist_console-1.6.0-py3-none-any.whl", hash = "sha256:074056c0364d6450636051bb1d49a0d07620ae9cf14b1b9ce17f130b8585adce", size = 11336, upload-time = "2025-04-02T10:29:52.132Z" }, +] + [[package]] name = "dnspython" version = "2.8.0" @@ -474,6 +487,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/f6/71d6ec9f18da0b2201287ce9db6afb1a1f637dedb3f0703409558981c723/ldap3-2.9.1-py2.py3-none-any.whl", hash = "sha256:5869596fc4948797020d3f03b7939da938778a0f9e2009f7a072ccf92b8e8d70", size = 432192, upload-time = "2021-07-18T06:34:12.905Z" }, ] +[[package]] +name = "libnacl" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/fc/65daa1a3fd7dd939133c30c6d393ea47e32317d2195619923b67daa29d60/libnacl-2.1.0.tar.gz", hash = "sha256:f3418da7df29e6d9b11fd7d990289d16397dc1020e4e35192e11aee826922860", size = 42189, upload-time = "2023-08-06T21:23:56.86Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/ce/85fa0276de7303b44fef63e07c14d618b8630bbe41c7dd7e34db246eab8d/libnacl-2.1.0-py3-none-any.whl", hash = "sha256:a8546b221afe8b72b6a9f298cd92a4c1f90570d7b5baa295acb1913644e230a5", size = 21870, upload-time = "2023-08-06T21:23:55.12Z" }, +] + [[package]] name = "loguru" version = "0.7.3" @@ -543,6 +565,7 @@ dependencies = [ { name = "bcrypt" }, { name = "cryptography" }, { name = "dishka" }, + { name = "dnsdist-console" }, { name = "dnspython" }, { name = "fastapi" }, { name = "fastapi-error-map" }, @@ -597,6 +620,7 @@ requires-dist = [ { name = "bcrypt", specifier = "==4.0.1" }, { name = "cryptography", specifier = ">=44.0.1" }, { name = "dishka", specifier = ">=1.6.0" }, + { name = "dnsdist-console", specifier = ">=1.6.0" }, { name = "dnspython", specifier = ">=2.7.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "fastapi-error-map", specifier = ">=0.9.8" }, @@ -1013,6 +1037,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/f7/70aad26e5877c8f7ee5b161c4c9fa0100e63fc4c944dc6d97b9c7e871417/ruff-0.11.9-py3-none-win_arm64.whl", hash = "sha256:bcf42689c22f2e240f496d0c183ef2c6f7b35e809f12c1db58f75d9aa8d630ca", size = 10741080, upload-time = "2025-05-09T16:19:39.605Z" }, ] +[[package]] +name = "scrypt" +version = "0.9.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/db/38/c9b79f61c04fa79b8fae28213111a6f70d8249d4d789ca7030453326ab62/scrypt-0.9.4.tar.gz", hash = "sha256:0d212010ba8c2e55475ba6258f30cee4da0432017514d8f6e855b7f1f8c55c77", size = 84526, upload-time = "2025-08-05T05:54:37.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/22/98e17e1ea6461a5c51c866192304182846fd004852f789dfece9f44c6553/scrypt-0.9.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:58f424ac1656d342b2651bf5577f1b2aad9959c2e41ebbadf591035b372368a9", size = 2293992, upload-time = "2025-08-05T06:00:53.887Z" }, + { url = 
"https://files.pythonhosted.org/packages/32/d9/076f90cb1086e32ebab30123952f4f162c80c53b8814e35bedb4aa720241/scrypt-0.9.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efaa359fab4682215d8826c5ff6ecda525d37eabfc0da4ab197a3fd95e6d5f87", size = 1508819, upload-time = "2025-08-05T06:04:54.445Z" }, + { url = "https://files.pythonhosted.org/packages/87/c0/d59f086fc8a589db06eac0b3829e2956de7a4c34d7c99c20b3cf6a858627/scrypt-0.9.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ee29481f0751eb4e91c4fca8895d44822d523225c796e0ed016a550e7f20f582", size = 2015695, upload-time = "2025-08-05T06:04:55.477Z" }, + { url = "https://files.pythonhosted.org/packages/98/0d/f62590144acf914a2eb4c3689cc6a2ca737a379962b0358f00dd0a1445cb/scrypt-0.9.4-cp313-cp313-win_amd64.whl", hash = "sha256:6ae6a0f7ccf7df9f0612b9166abbff5b6dacc41661044a0d090111aeb5e0bcf2", size = 47267, upload-time = "2025-08-05T06:03:57.035Z" }, +] + [[package]] name = "setuptools" version = "80.9.0"