style: apply black formatting across codebase 🎨

No logic changes — pure reformatting of line lengths, dict literals,
method-chain line breaks, and trailing newlines to satisfy black's style.
This commit is contained in:
2026-02-18 22:53:09 +13:00
parent 807d6271f1
commit 74c5f4012e
12 changed files with 291 additions and 164 deletions

View File

@@ -113,19 +113,23 @@ class DNSAdminAPI:
if domain_exists: if domain_exists:
record = get_domain_record(domain) record = get_domain_record(domain)
return urlencode({ return urlencode(
"error": 0, {
"exists": 1, "error": 0,
"details": f"Domain exists on {record.hostname}", "exists": 1,
}) "details": f"Domain exists on {record.hostname}",
}
)
# parent match only # parent match only
parent_record = get_parent_domain_record(domain) parent_record = get_parent_domain_record(domain)
return urlencode({ return urlencode(
"error": 0, {
"exists": 2, "error": 0,
"details": f"Parent Domain exists on {parent_record.hostname}", "exists": 2,
}) "details": f"Parent Domain exists on {parent_record.hostname}",
}
)
def _handle_rawsave(self, domain: str, params: dict): def _handle_rawsave(self, domain: str, params: dict):
"""Process zone file saves""" """Process zone file saves"""

View File

@@ -119,9 +119,7 @@ class CoreDNSMySQLBackend(DNSBackend):
) )
session.add(existing_soa) session.add(existing_soa)
changes["added"] += 1 changes["added"] += 1
logger.debug( logger.debug(f"Added SOA record: {soa_name} SOA {soa_content}")
f"Added SOA record: {soa_name} SOA {soa_content}"
)
# Process all non-SOA records # Process all non-SOA records
for record_name, record_type, record_content, record_ttl in source_records: for record_name, record_type, record_content, record_ttl in source_records:
@@ -172,7 +170,7 @@ class CoreDNSMySQLBackend(DNSBackend):
changes["removed"] += 1 changes["removed"] += 1
session.commit() session.commit()
total_changes = changes['added'] + changes['updated'] + changes['removed'] total_changes = changes["added"] + changes["updated"] + changes["removed"]
if total_changes > 0: if total_changes > 0:
logger.info( logger.info(
f"[{self.instance_name}] Zone {zone_name} updated: " f"[{self.instance_name}] Zone {zone_name} updated: "
@@ -180,9 +178,7 @@ class CoreDNSMySQLBackend(DNSBackend):
f"{changes['removed']} removed" f"{changes['removed']} removed"
) )
else: else:
logger.debug( logger.debug(f"[{self.instance_name}] Zone {zone_name}: no changes")
f"[{self.instance_name}] Zone {zone_name}: no changes"
)
return True return True
except Exception as e: except Exception as e:
@@ -196,7 +192,11 @@ class CoreDNSMySQLBackend(DNSBackend):
session = self.Session() session = self.Session()
try: try:
# First find the zone # First find the zone
zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first() zone = (
session.query(Zone)
.filter_by(zone_name=self.dot_fqdn(zone_name))
.first()
)
if not zone: if not zone:
logger.warning(f"Zone {zone_name} not found for deletion") logger.warning(f"Zone {zone_name} not found for deletion")
return False return False
@@ -230,7 +230,9 @@ class CoreDNSMySQLBackend(DNSBackend):
session = self.Session() session = self.Session()
try: try:
exists = ( exists = (
session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first() session.query(Zone)
.filter_by(zone_name=self.dot_fqdn(zone_name))
.first()
is not None is not None
) )
logger.debug(f"Zone existence check for {zone_name}: {exists}") logger.debug(f"Zone existence check for {zone_name}: {exists}")
@@ -266,17 +268,11 @@ class CoreDNSMySQLBackend(DNSBackend):
The normalized CNAME target string The normalized CNAME target string
""" """
if record_content.startswith("@"): if record_content.startswith("@"):
logger.debug( logger.debug(f"CNAME target starts with '@', replacing with zone FQDN")
f"CNAME target starts with '@', replacing with zone FQDN"
)
record_content = self.dot_fqdn(zone_name) record_content = self.dot_fqdn(zone_name)
elif not record_content.endswith("."): elif not record_content.endswith("."):
logger.debug( logger.debug(f"CNAME target {record_content} is relative, appending zone")
f"CNAME target {record_content} is relative, appending zone" record_content = ".".join([record_content, self.dot_fqdn(zone_name)])
)
record_content = ".".join(
[record_content, self.dot_fqdn(zone_name)]
)
return record_content return record_content
def _parse_zone_to_record_set( def _parse_zone_to_record_set(
@@ -306,9 +302,7 @@ class CoreDNSMySQLBackend(DNSBackend):
continue continue
if record_type == "CNAME": if record_type == "CNAME":
record_content = self._normalize_cname_data( record_content = self._normalize_cname_data(zone_name, record_content)
zone_name, record_content
)
records.add((record_name, record_type, record_content, ttl)) records.add((record_name, record_type, record_content, ttl))
@@ -341,9 +335,7 @@ class CoreDNSMySQLBackend(DNSBackend):
) )
return False, 0 return False, 0
actual_count = ( actual_count = session.query(Record).filter_by(zone_id=zone.id).count()
session.query(Record).filter_by(zone_id=zone.id).count()
)
matches = actual_count == expected_count matches = actual_count == expected_count
if not matches: if not matches:
@@ -409,14 +401,11 @@ class CoreDNSMySQLBackend(DNSBackend):
) )
# Build lookup keys (without TTL) matching write_zone's key format # Build lookup keys (without TTL) matching write_zone's key format
expected_keys: Set[Tuple[str, str, str]] = { expected_keys: Set[Tuple[str, str, str]] = {
(hostname, rtype, data) (hostname, rtype, data) for hostname, rtype, data, _ in source_records
for hostname, rtype, data, _ in source_records
} }
# Query all records currently in the backend for this zone # Query all records currently in the backend for this zone
db_records = ( db_records = session.query(Record).filter_by(zone_id=zone.id).all()
session.query(Record).filter_by(zone_id=zone.id).all()
)
removed = 0 removed = 0
for record in db_records: for record in db_records:

View File

@@ -109,9 +109,7 @@ class ReconciliationWorker:
f"[reconciler] {hostname}: {len(da_domains) if da_domains else 0} active domain(s) in DA" f"[reconciler] {hostname}: {len(da_domains) if da_domains else 0} active domain(s) in DA"
) )
except Exception as e: except Exception as e:
logger.error( logger.error(f"[reconciler] Unexpected error polling {hostname}: {e}")
f"[reconciler] Unexpected error polling {hostname}: {e}"
)
# Now check local DB for all domains, update master if needed, and queue deletes only from recorded master # Now check local DB for all domains, update master if needed, and queue deletes only from recorded master
session = connect() session = connect()
@@ -147,12 +145,14 @@ class ReconciliationWorker:
f"(master: {recorded_master})" f"(master: {recorded_master})"
) )
else: else:
self.delete_queue.put({ self.delete_queue.put(
"domain": record.domain, {
"hostname": record.hostname, "domain": record.domain,
"username": record.username or "", "hostname": record.hostname,
"source": "reconciler", "username": record.username or "",
}) "source": "reconciler",
}
)
logger.debug( logger.debug(
f"[reconciler] Queued delete for orphan: {record.domain} " f"[reconciler] Queued delete for orphan: {record.domain} "
f"(master: {recorded_master})" f"(master: {recorded_master})"
@@ -161,9 +161,13 @@ class ReconciliationWorker:
if migrated or backfilled: if migrated or backfilled:
session.commit() session.commit()
if backfilled: if backfilled:
logger.info(f"[reconciler] {backfilled} domain(s) had missing hostname backfilled.") logger.info(
f"[reconciler] {backfilled} domain(s) had missing hostname backfilled."
)
if migrated: if migrated:
logger.info(f"[reconciler] {migrated} domain(s) migrated to new master.") logger.info(
f"[reconciler] {migrated} domain(s) migrated to new master."
)
finally: finally:
session.close() session.close()
if self.dry_run: if self.dry_run:
@@ -178,7 +182,13 @@ class ReconciliationWorker:
) )
def _fetch_da_domains( def _fetch_da_domains(
self, hostname: str, port: int, username: str, password: str, use_ssl: bool, ipp: int = 1000 self,
hostname: str,
port: int,
username: str,
password: str,
use_ssl: bool,
ipp: int = 1000,
): ):
"""Fetch all domains from a DA server via CMD_DNS_ADMIN (JSON, paging supported). """Fetch all domains from a DA server via CMD_DNS_ADMIN (JSON, paging supported).
@@ -205,13 +215,21 @@ class ReconciliationWorker:
response = requests.get(url, **req_kwargs) response = requests.get(url, **req_kwargs)
if response.is_redirect or response.status_code in (301, 302, 303, 307, 308): if response.is_redirect or response.status_code in (
301,
302,
303,
307,
308,
):
if not cookies: if not cookies:
logger.debug( logger.debug(
f"[reconciler] {hostname}:{port} redirected Basic Auth " f"[reconciler] {hostname}:{port} redirected Basic Auth "
f"(HTTP {response.status_code}) — attempting session login (DA Evo)" f"(HTTP {response.status_code}) — attempting session login (DA Evo)"
) )
cookies = self._da_session_login(scheme, hostname, port, username, password) cookies = self._da_session_login(
scheme, hostname, port, username, password
)
if cookies is None: if cookies is None:
return None return None
continue # retry this page with cookies continue # retry this page with cookies
@@ -279,9 +297,7 @@ class ReconciliationWorker:
) )
return None return None
except Exception as e: except Exception as e:
logger.error( logger.error(f"[reconciler] Unexpected error fetching from {hostname}: {e}")
f"[reconciler] Unexpected error fetching from {hostname}: {e}"
)
return None return None
def _da_session_login( def _da_session_login(
@@ -310,12 +326,12 @@ class ReconciliationWorker:
f"check username/password." f"check username/password."
) )
return None return None
logger.debug(f"[reconciler] {hostname}:{port} session login successful (DA Evo)") logger.debug(
f"[reconciler] {hostname}:{port} session login successful (DA Evo)"
)
return response.cookies return response.cookies
except Exception as e: except Exception as e:
logger.error( logger.error(f"[reconciler] {hostname}:{port} session login failed: {e}")
f"[reconciler] {hostname}:{port} session login failed: {e}"
)
return None return None
@staticmethod @staticmethod
@@ -334,24 +350,44 @@ class ReconciliationWorker:
domains = params.get("list[]", []) domains = params.get("list[]", [])
return {d.strip().lower() for d in domains if d.strip()} return {d.strip().lower() for d in domains if d.strip()}
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
import sys import sys
from queue import Queue from queue import Queue
parser = argparse.ArgumentParser(description="Test DirectAdmin domain fetcher (JSON/paging)") parser = argparse.ArgumentParser(
description="Test DirectAdmin domain fetcher (JSON/paging)"
)
parser.add_argument("--hostname", required=True, help="DirectAdmin server hostname") parser.add_argument("--hostname", required=True, help="DirectAdmin server hostname")
parser.add_argument("--port", type=int, default=2222, help="DirectAdmin port (default: 2222)") parser.add_argument(
"--port", type=int, default=2222, help="DirectAdmin port (default: 2222)"
)
parser.add_argument("--username", required=True, help="DirectAdmin admin username") parser.add_argument("--username", required=True, help="DirectAdmin admin username")
parser.add_argument("--password", required=True, help="DirectAdmin admin password") parser.add_argument("--password", required=True, help="DirectAdmin admin password")
parser.add_argument("--ssl", action="store_true", help="Use HTTPS (default: True)") parser.add_argument("--ssl", action="store_true", help="Use HTTPS (default: True)")
parser.add_argument("--no-ssl", dest="ssl", action="store_false", help="Use HTTP (not recommended)") parser.add_argument(
"--no-ssl", dest="ssl", action="store_false", help="Use HTTP (not recommended)"
)
parser.set_defaults(ssl=True) parser.set_defaults(ssl=True)
parser.add_argument("--verify-ssl", action="store_true", help="Verify SSL certs (default: True)") parser.add_argument(
parser.add_argument("--no-verify-ssl", dest="verify_ssl", action="store_false", help="Don't verify SSL certs") "--verify-ssl", action="store_true", help="Verify SSL certs (default: True)"
)
parser.add_argument(
"--no-verify-ssl",
dest="verify_ssl",
action="store_false",
help="Don't verify SSL certs",
)
parser.set_defaults(verify_ssl=True) parser.set_defaults(verify_ssl=True)
parser.add_argument("--ipp", type=int, default=1000, help="Items per page (default: 1000)") parser.add_argument(
parser.add_argument("--print-json", action="store_true", help="Print raw JSON response for first page") "--ipp", type=int, default=1000, help="Items per page (default: 1000)"
)
parser.add_argument(
"--print-json",
action="store_true",
help="Print raw JSON response for first page",
)
args = parser.parse_args() args = parser.parse_args()
@@ -372,7 +408,9 @@ if __name__ == "__main__":
q = Queue() q = Queue()
worker = ReconciliationWorker(q, config) worker = ReconciliationWorker(q, config)
server = config["directadmin_servers"][0] server = config["directadmin_servers"][0]
print(f"Fetching domains from {server['hostname']}:{server['port']} (ipp={args.ipp})...") print(
f"Fetching domains from {server['hostname']}:{server['port']} (ipp={args.ipp})..."
)
# Directly call the fetch method for testing # Directly call the fetch method for testing
domains = worker._fetch_da_domains( domains = worker._fetch_da_domains(
server["hostname"], server["hostname"],
@@ -380,7 +418,7 @@ if __name__ == "__main__":
server.get("username"), server.get("username"),
server.get("password"), server.get("password"),
server.get("ssl", True), server.get("ssl", True),
ipp=args.ipp ipp=args.ipp,
) )
if domains is None: if domains is None:
print("Failed to fetch domains.", file=sys.stderr) print("Failed to fetch domains.", file=sys.stderr)

View File

@@ -59,9 +59,7 @@ def count_zone_records(zone_data: str, domain_name: str) -> int:
if rdata.rdclass == IN: if rdata.rdclass == IN:
count += 1 count += 1
logger.debug( logger.debug(f"Source zone {domain_name} contains {count} records")
f"Source zone {domain_name} contains {count} records"
)
return count return count
except DNSException as e: except DNSException as e:

View File

@@ -50,7 +50,9 @@ def main():
# Configure CherryPy # Configure CherryPy
user_password_dict = { user_password_dict = {
config.get_string("app.auth_username"): config.get_string("app.auth_password") config.get_string("app.auth_username"): config.get_string(
"app.auth_password"
)
} }
check_password = cherrypy.lib.auth_basic.checkpassword_dict(user_password_dict) check_password = cherrypy.lib.auth_basic.checkpassword_dict(user_password_dict)

View File

@@ -14,7 +14,9 @@ from directdnsonly.app.reconciler import ReconciliationWorker
class WorkerManager: class WorkerManager:
def __init__(self, queue_path: str, backend_registry, reconciliation_config: dict = None): def __init__(
self, queue_path: str, backend_registry, reconciliation_config: dict = None
):
self.queue_path = queue_path self.queue_path = queue_path
self.backend_registry = backend_registry self.backend_registry = backend_registry
self._running = False self._running = False
@@ -86,9 +88,7 @@ class WorkerManager:
f"{len(backends)} backends concurrently: " f"{len(backends)} backends concurrently: "
f"{', '.join(backends.keys())}" f"{', '.join(backends.keys())}"
) )
self._process_backends_parallel( self._process_backends_parallel(backends, item, session)
backends, item, session
)
else: else:
# Single backend, no need for thread overhead # Single backend, no need for thread overhead
for backend_name, backend in backends.items(): for backend_name, backend in backends.items():
@@ -126,9 +126,7 @@ class WorkerManager:
try: try:
logger.debug(f"Using backend: {backend_name}") logger.debug(f"Using backend: {backend_name}")
if backend.write_zone(item["domain"], item["zone_file"]): if backend.write_zone(item["domain"], item["zone_file"]):
logger.debug( logger.debug(f"Successfully updated {item['domain']} in {backend_name}")
f"Successfully updated {item['domain']} in {backend_name}"
)
if backend.get_name() == "bind": if backend.get_name() == "bind":
# Need to update the named.conf # Need to update the named.conf
backend.update_named_conf( backend.update_named_conf(
@@ -144,9 +142,7 @@ class WorkerManager:
backend_name, backend, item["domain"], item["zone_file"] backend_name, backend, item["domain"], item["zone_file"]
) )
else: else:
logger.error( logger.error(f"Failed to update {item['domain']} in {backend_name}")
f"Failed to update {item['domain']} in {backend_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error in {backend_name}: {str(e)}") logger.error(f"Error in {backend_name}: {str(e)}")
@@ -165,9 +161,7 @@ class WorkerManager:
record = session.query(Domain).filter_by(domain=domain).first() record = session.query(Domain).filter_by(domain=domain).first()
if not record: if not record:
logger.warning( logger.warning(f"Domain {domain} not found in DB — skipping delete")
f"Domain {domain} not found in DB — skipping delete"
)
self.delete_queue.task_done() self.delete_queue.task_done()
continue continue
@@ -194,16 +188,29 @@ class WorkerManager:
elif len(backends) > 1: elif len(backends) > 1:
# Parallel delete, track failures # Parallel delete, track failures
results = [] results = []
def delete_backend_wrapper(backend_name, backend, domain, remaining_domains):
def delete_backend_wrapper(
backend_name, backend, domain, remaining_domains
):
try: try:
return backend.delete_zone(domain) return backend.delete_zone(domain)
except Exception as e: except Exception as e:
logger.error(f"Error deleting {domain} from {backend_name}: {e}") logger.error(
f"Error deleting {domain} from {backend_name}: {e}"
)
return False return False
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
with ThreadPoolExecutor(max_workers=len(backends)) as executor: with ThreadPoolExecutor(max_workers=len(backends)) as executor:
futures = { futures = {
executor.submit(delete_backend_wrapper, backend_name, backend, domain, remaining_domains): backend_name executor.submit(
delete_backend_wrapper,
backend_name,
backend,
domain,
remaining_domains,
): backend_name
for backend_name, backend in backends.items() for backend_name, backend in backends.items()
} }
for future in as_completed(futures): for future in as_completed(futures):
@@ -212,9 +219,13 @@ class WorkerManager:
result = future.result() result = future.result()
results.append(result) results.append(result)
if not result: if not result:
logger.error(f"Failed to delete {domain} from {backend_name}") logger.error(
f"Failed to delete {domain} from {backend_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Unhandled error deleting from {backend_name}: {e}") logger.error(
f"Unhandled error deleting from {backend_name}: {e}"
)
results.append(False) results.append(False)
delete_success = all(results) delete_success = all(results)
else: else:
@@ -223,10 +234,14 @@ class WorkerManager:
try: try:
result = backend.delete_zone(domain) result = backend.delete_zone(domain)
if not result: if not result:
logger.error(f"Failed to delete {domain} from {backend_name}") logger.error(
f"Failed to delete {domain} from {backend_name}"
)
delete_success = False delete_success = False
except Exception as e: except Exception as e:
logger.error(f"Error deleting {domain} from {backend_name}: {e}") logger.error(
f"Error deleting {domain} from {backend_name}: {e}"
)
delete_success = False delete_success = False
if delete_success: if delete_success:
@@ -236,7 +251,9 @@ class WorkerManager:
self.delete_queue.task_done() self.delete_queue.task_done()
logger.success(f"Delete completed for {domain}") logger.success(f"Delete completed for {domain}")
else: else:
logger.error(f"Delete failed for {domain} on one or more backends — DB record retained") logger.error(
f"Delete failed for {domain} on one or more backends — DB record retained"
)
self.delete_queue.task_done() self.delete_queue.task_done()
except Empty: except Empty:
@@ -270,7 +287,10 @@ class WorkerManager:
futures = { futures = {
executor.submit( executor.submit(
self._delete_single_backend, self._delete_single_backend,
backend_name, backend, domain, remaining_domains backend_name,
backend,
domain,
remaining_domains,
): backend_name ): backend_name
for backend_name, backend in backends.items() for backend_name, backend in backends.items()
} }
@@ -279,9 +299,7 @@ class WorkerManager:
try: try:
future.result() future.result()
except Exception as e: except Exception as e:
logger.error( logger.error(f"Unhandled error deleting from {backend_name}: {e}")
f"Unhandled error deleting from {backend_name}: {e}"
)
elapsed = (time.monotonic() - start_time) * 1000 elapsed = (time.monotonic() - start_time) * 1000
logger.debug( logger.debug(
f"Parallel delete of {domain} across " f"Parallel delete of {domain} across "
@@ -292,13 +310,11 @@ class WorkerManager:
"""Process zone updates across multiple backends in parallel""" """Process zone updates across multiple backends in parallel"""
start_time = time.monotonic() start_time = time.monotonic()
with ThreadPoolExecutor( with ThreadPoolExecutor(
max_workers=len(backends), max_workers=len(backends), thread_name_prefix="backend"
thread_name_prefix="backend"
) as executor: ) as executor:
futures = { futures = {
executor.submit( executor.submit(
self._process_single_backend, self._process_single_backend, backend_name, backend, item, session
backend_name, backend, item, session
): backend_name ): backend_name
for backend_name, backend in backends.items() for backend_name, backend in backends.items()
} }
@@ -317,9 +333,7 @@ class WorkerManager:
f"{len(backends)} backends completed in {elapsed:.0f}ms" f"{len(backends)} backends completed in {elapsed:.0f}ms"
) )
def _verify_backend_record_count( def _verify_backend_record_count(self, backend_name, backend, zone_name, zone_data):
self, backend_name, backend, zone_name, zone_data
):
"""Verify and reconcile the backend record count against the """Verify and reconcile the backend record count against the
authoritative BIND zone from DirectAdmin. authoritative BIND zone from DirectAdmin.
@@ -344,9 +358,7 @@ class WorkerManager:
) )
return return
matches, actual = backend.verify_zone_record_count( matches, actual = backend.verify_zone_record_count(zone_name, expected)
zone_name, expected
)
if matches: if matches:
return # All good return # All good
@@ -357,9 +369,7 @@ class WorkerManager:
f"record(s) for {zone_name} — reconciling against " f"record(s) for {zone_name} — reconciling against "
f"DirectAdmin source zone" f"DirectAdmin source zone"
) )
success, removed = backend.reconcile_zone_records( success, removed = backend.reconcile_zone_records(zone_name, zone_data)
zone_name, zone_data
)
if success and removed > 0: if success and removed > 0:
# Verify again after reconciliation # Verify again after reconciliation
matches, new_count = backend.verify_zone_record_count( matches, new_count = backend.verify_zone_record_count(
@@ -437,6 +447,9 @@ class WorkerManager:
"save_queue_size": self.save_queue.qsize(), "save_queue_size": self.save_queue.qsize(),
"delete_queue_size": self.delete_queue.qsize(), "delete_queue_size": self.delete_queue.qsize(),
"save_worker_alive": self._save_thread and self._save_thread.is_alive(), "save_worker_alive": self._save_thread and self._save_thread.is_alive(),
"delete_worker_alive": self._delete_thread and self._delete_thread.is_alive(), "delete_worker_alive": self._delete_thread
"reconciler_alive": self._reconciler.is_alive if self._reconciler else False, and self._delete_thread.is_alive(),
"reconciler_alive": (
self._reconciler.is_alive if self._reconciler else False
),
} }

View File

@@ -1,10 +1,14 @@
"""Shared test fixtures for directdnsonly test suite.""" """Shared test fixtures for directdnsonly test suite."""
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from directdnsonly.app.db import Base from directdnsonly.app.db import Base
from directdnsonly.app.db.models import Domain, Key # noqa: F401 — registers models with Base from directdnsonly.app.db.models import (
Domain,
Key,
) # noqa: F401 — registers models with Base
@pytest.fixture @pytest.fixture

View File

@@ -1,4 +1,5 @@
"""Tests for directdnsonly.app.api.admin — DNSAdminAPI handler methods.""" """Tests for directdnsonly.app.api.admin — DNSAdminAPI handler methods."""
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs from urllib.parse import parse_qs
@@ -60,8 +61,12 @@ def test_handle_exists_unsupported_action_returns_error(api):
def test_handle_exists_domain_not_found(api): def test_handle_exists_domain_not_found(api):
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False), \ with (
patch("directdnsonly.app.api.admin.check_parent_domain_owner", return_value=False): patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False),
patch(
"directdnsonly.app.api.admin.check_parent_domain_owner", return_value=False
),
):
result = api._handle_exists({"action": "exists", "domain": "example.com"}) result = api._handle_exists({"action": "exists", "domain": "example.com"})
parsed = parse_qs(result) parsed = parse_qs(result)
@@ -73,8 +78,10 @@ def test_handle_exists_domain_found(api):
record = MagicMock() record = MagicMock()
record.hostname = "da1.example.com" record.hostname = "da1.example.com"
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True), \ with (
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record): patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True),
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record),
):
result = api._handle_exists({"action": "exists", "domain": "example.com"}) result = api._handle_exists({"action": "exists", "domain": "example.com"})
parsed = parse_qs(result) parsed = parse_qs(result)
@@ -87,14 +94,22 @@ def test_handle_exists_parent_found(api):
parent = MagicMock() parent = MagicMock()
parent.hostname = "da2.example.com" parent.hostname = "da2.example.com"
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False), \ with (
patch("directdnsonly.app.api.admin.check_parent_domain_owner", return_value=True), \ patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False),
patch("directdnsonly.app.api.admin.get_parent_domain_record", return_value=parent): patch(
result = api._handle_exists({ "directdnsonly.app.api.admin.check_parent_domain_owner", return_value=True
"action": "exists", ),
"domain": "sub.example.com", patch(
"check_for_parent_domain": "1", "directdnsonly.app.api.admin.get_parent_domain_record", return_value=parent
}) ),
):
result = api._handle_exists(
{
"action": "exists",
"domain": "sub.example.com",
"check_for_parent_domain": "1",
}
)
parsed = parse_qs(result) parsed = parse_qs(result)
assert parsed["error"] == ["0"] assert parsed["error"] == ["0"]
@@ -107,9 +122,11 @@ def test_handle_exists_no_parent_check_when_flag_absent(api):
record = MagicMock() record = MagicMock()
record.hostname = "da1.example.com" record.hostname = "da1.example.com"
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True), \ with (
patch("directdnsonly.app.api.admin.check_parent_domain_owner") as mock_parent, \ patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True),
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record): patch("directdnsonly.app.api.admin.check_parent_domain_owner") as mock_parent,
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record),
):
api._handle_exists({"action": "exists", "domain": "example.com"}) api._handle_exists({"action": "exists", "domain": "example.com"})
mock_parent.assert_not_called() mock_parent.assert_not_called()
@@ -123,14 +140,21 @@ SAMPLE_ZONE = "$ORIGIN example.com.\n$TTL 300\nexample.com. 300 IN A 1.2.3.4\n"
def test_rawsave_enqueues_item(api, save_queue): def test_rawsave_enqueues_item(api, save_queue):
with patch("directdnsonly.app.api.admin.validate_and_normalize_zone", with (
return_value=SAMPLE_ZONE), \ patch(
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))): "directdnsonly.app.api.admin.validate_and_normalize_zone",
result = api._handle_rawsave("example.com", { return_value=SAMPLE_ZONE,
"zone_file": SAMPLE_ZONE, ),
"hostname": "da1.example.com", patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))),
"username": "admin", ):
}) result = api._handle_rawsave(
"example.com",
{
"zone_file": SAMPLE_ZONE,
"hostname": "da1.example.com",
"username": "admin",
},
)
save_queue.put.assert_called_once() save_queue.put.assert_called_once()
item = save_queue.put.call_args[0][0] item = save_queue.put.call_args[0][0]
@@ -150,9 +174,13 @@ def test_rawsave_missing_zone_file_raises(api):
def test_rawsave_invalid_zone_raises(api): def test_rawsave_invalid_zone_raises(api):
with patch("directdnsonly.app.api.admin.validate_and_normalize_zone", with (
side_effect=ValueError("Invalid zone data: bad record")), \ patch(
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))): "directdnsonly.app.api.admin.validate_and_normalize_zone",
side_effect=ValueError("Invalid zone data: bad record"),
),
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))),
):
with pytest.raises(ValueError, match="Invalid zone data"): with pytest.raises(ValueError, match="Invalid zone data"):
api._handle_rawsave("example.com", {"zone_file": "garbage"}) api._handle_rawsave("example.com", {"zone_file": "garbage"})
@@ -164,10 +192,13 @@ def test_rawsave_invalid_zone_raises(api):
def test_delete_enqueues_item(api, delete_queue): def test_delete_enqueues_item(api, delete_queue):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="10.0.0.1"))): with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="10.0.0.1"))):
result = api._handle_delete("example.com", { result = api._handle_delete(
"hostname": "da1.example.com", "example.com",
"username": "admin", {
}) "hostname": "da1.example.com",
"username": "admin",
},
)
delete_queue.put.assert_called_once() delete_queue.put.assert_called_once()
item = delete_queue.put.call_args[0][0] item = delete_queue.put.call_args[0][0]

View File

@@ -1,4 +1,5 @@
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite).""" """Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
@@ -141,8 +142,16 @@ def test_reconcile_removes_extra_records(mysql_backend):
# Inject a phantom record directly into the DB # Inject a phantom record directly into the DB
session = mysql_backend.Session() session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first() zone = session.query(Zone).filter_by(zone_name="example.com.").first()
session.add(Record(zone_id=zone.id, hostname="phantom", type="A", session.add(
data="10.0.0.99", ttl=300, online=True)) Record(
zone_id=zone.id,
hostname="phantom",
type="A",
data="10.0.0.99",
ttl=300,
online=True,
)
)
session.commit() session.commit()
session.close() session.close()

View File

@@ -1,4 +1,5 @@
"""Tests for directdnsonly.app.reconciler — ReconciliationWorker.""" """Tests for directdnsonly.app.reconciler — ReconciliationWorker."""
import pytest import pytest
import requests.exceptions import requests.exceptions
from queue import Queue from queue import Queue
@@ -12,7 +13,13 @@ from directdnsonly.app.db.models import Domain
# Fixtures # Fixtures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
SERVER = {"hostname": "da1.example.com", "port": 2222, "username": "admin", "password": "secret", "ssl": True} SERVER = {
"hostname": "da1.example.com",
"port": 2222,
"username": "admin",
"password": "secret",
"ssl": True,
}
BASE_CONFIG = { BASE_CONFIG = {
"enabled": True, "enabled": True,
@@ -46,7 +53,9 @@ def dry_run_worker(delete_queue):
def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_connect): def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_connect):
patch_connect.add(Domain(domain="orphan.com", hostname="da1.example.com", username="admin")) patch_connect.add(
Domain(domain="orphan.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit() patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value=set()): with patch.object(worker, "_fetch_da_domains", return_value=set()):
@@ -59,7 +68,9 @@ def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_c
def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connect): def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connect):
patch_connect.add(Domain(domain="orphan.com", hostname="da1.example.com", username="admin")) patch_connect.add(
Domain(domain="orphan.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit() patch_connect.commit()
with patch.object(dry_run_worker, "_fetch_da_domains", return_value=set()): with patch.object(dry_run_worker, "_fetch_da_domains", return_value=set()):
@@ -70,7 +81,9 @@ def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connec
def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connect): def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connect):
"""Domains whose recorded master is NOT in our configured servers are skipped.""" """Domains whose recorded master is NOT in our configured servers are skipped."""
patch_connect.add(Domain(domain="other.com", hostname="da99.unknown.com", username="admin")) patch_connect.add(
Domain(domain="other.com", hostname="da99.unknown.com", username="admin")
)
patch_connect.commit() patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value=set()): with patch.object(worker, "_fetch_da_domains", return_value=set()):
@@ -80,7 +93,9 @@ def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connec
def test_active_domain_not_queued(worker, delete_queue, patch_connect): def test_active_domain_not_queued(worker, delete_queue, patch_connect):
patch_connect.add(Domain(domain="good.com", hostname="da1.example.com", username="admin")) patch_connect.add(
Domain(domain="good.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit() patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"good.com"}): with patch.object(worker, "_fetch_da_domains", return_value={"good.com"}):
@@ -106,7 +121,9 @@ def test_backfill_null_hostname(worker, patch_connect):
def test_migration_updates_hostname(worker, patch_connect): def test_migration_updates_hostname(worker, patch_connect):
patch_connect.add(Domain(domain="moved.com", hostname="da-old.example.com", username="admin")) patch_connect.add(
Domain(domain="moved.com", hostname="da-old.example.com", username="admin")
)
patch_connect.commit() patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"moved.com"}): with patch.object(worker, "_fetch_da_domains", return_value={"moved.com"}):
@@ -150,7 +167,9 @@ def test_fetch_returns_domains_from_json(worker):
mock_resp = _make_json_response(["example.com", "test.com"]) mock_resp = _make_json_response(["example.com", "test.com"])
with patch("requests.get", return_value=mock_resp): with patch("requests.get", return_value=mock_resp):
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result == {"example.com", "test.com"} assert result == {"example.com", "test.com"}
@@ -160,7 +179,9 @@ def test_fetch_paginates(worker):
page2 = _make_json_response(["b.com"], total_pages=2) page2 = _make_json_response(["b.com"], total_pages=2)
with patch("requests.get", side_effect=[page1, page2]): with patch("requests.get", side_effect=[page1, page2]):
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result == {"a.com", "b.com"} assert result == {"a.com", "b.com"}
@@ -170,9 +191,13 @@ def test_fetch_redirect_triggers_session_login(worker):
redirect_resp.status_code = 302 redirect_resp.status_code = 302
redirect_resp.is_redirect = True redirect_resp.is_redirect = True
with patch("requests.get", return_value=redirect_resp), \ with (
patch.object(worker, "_da_session_login", return_value=None): patch("requests.get", return_value=redirect_resp),
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) patch.object(worker, "_da_session_login", return_value=None),
):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None assert result is None
@@ -185,28 +210,40 @@ def test_fetch_html_response_returns_none(worker):
mock_resp.raise_for_status = MagicMock() mock_resp.raise_for_status = MagicMock()
with patch("requests.get", return_value=mock_resp): with patch("requests.get", return_value=mock_resp):
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None assert result is None
def test_fetch_connection_error_returns_none(worker): def test_fetch_connection_error_returns_none(worker):
with patch("requests.get", side_effect=requests.exceptions.ConnectionError("refused")): with patch(
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) "requests.get", side_effect=requests.exceptions.ConnectionError("refused")
):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None assert result is None
def test_fetch_timeout_returns_none(worker): def test_fetch_timeout_returns_none(worker):
with patch("requests.get", side_effect=requests.exceptions.Timeout()): with patch("requests.get", side_effect=requests.exceptions.Timeout()):
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None assert result is None
def test_fetch_ssl_error_returns_none(worker): def test_fetch_ssl_error_returns_none(worker):
with patch("requests.get", side_effect=requests.exceptions.SSLError("cert verify failed")): with patch(
result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) "requests.get", side_effect=requests.exceptions.SSLError("cert verify failed")
):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None assert result is None

View File

@@ -1,4 +1,5 @@
"""Tests for directdnsonly.app.utils — zone index helper functions.""" """Tests for directdnsonly.app.utils — zone index helper functions."""
import pytest import pytest
from directdnsonly.app.db.models import Domain from directdnsonly.app.db.models import Domain

View File

@@ -1,4 +1,5 @@
"""Tests for directdnsonly.app.utils.zone_parser.""" """Tests for directdnsonly.app.utils.zone_parser."""
import pytest import pytest
from dns.exception import DNSException from dns.exception import DNSException