From bd4622736456bb60dac2e8efe3696c28fb59114f Mon Sep 17 00:00:00 2001 From: Aaron Guise Date: Wed, 18 Feb 2026 22:03:04 +1300 Subject: [PATCH] =?UTF-8?q?feat:=20add=20test=20suite,=20fix=20backend=20b?= =?UTF-8?q?ugs,=20remove=20legacy=20artifacts=20=F0=9F=A7=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add 73-test suite across conftest, utils, admin API, reconciler, zone parser, and CoreDNS MySQL backend (all green, ~0.5s) - Fix zone_exists filter using wrong column name (name → zone_name) - Fix delete_zone missing dot_fqdn normalization on lookup - Remove spurious unused `from config import config` in coredns_mysql.py - Fix config loader to search module-relative path so tests find app.yml without needing a root-level config/ directory - Remove legacy v1 Flask prototype (app.py), empty config.json, and duplicate root config/app.yml --- app.py | 105 -------- config.json | 1 - config/app.yml | 29 --- directdnsonly/app/backends/coredns_mysql.py | 5 +- directdnsonly/app/reconciler.py | 161 ++++++------ directdnsonly/config/__init__.py | 8 +- directdnsonly/config/app.yml | 2 + directdnsonly/worker.py | 61 +++-- tests/conftest.py | 36 +++ tests/test_admin_api.py | 188 ++++++++++++++ tests/test_coredns_mysql.py | 151 +++++++++-- tests/test_reconciler.py | 262 ++++++++++++++++++++ tests/test_utils.py | 137 ++++++++++ tests/test_zone_parser.py | 100 ++++++++ 14 files changed, 982 insertions(+), 264 deletions(-) delete mode 100644 app.py delete mode 100644 config.json delete mode 100644 config/app.yml create mode 100644 tests/conftest.py create mode 100644 tests/test_admin_api.py create mode 100644 tests/test_reconciler.py create mode 100644 tests/test_utils.py create mode 100644 tests/test_zone_parser.py diff --git a/app.py b/app.py deleted file mode 100644 index 0d9a5ae..0000000 --- a/app.py +++ /dev/null @@ -1,105 +0,0 @@ -from flask import Flask, request -import mmap -import re - -app = Flask(__name__) - - -@app.route('/') -def hello_world(): - return 'Hello World!' - - -@app.route('/CMD_API_LOGIN_TEST') -def login_test(): - multi_dict = request.values - for key in multi_dict: - print(multi_dict.get(key)) - print(multi_dict.getlist(key)) - # print(request.values) - print(request.headers) - print(request.authorization) - - return 'error=0&text=Login OK&details=none' - - -@app.route('/CMD_API_DNS_ADMIN', methods=['GET', 'POST']) -def domain_admin(): - print(str(request.data, encoding="utf-8")) - print(request.values.get('action')) - action = request.values.get('action') - if action == 'exists': - # DirectAdmin is checking whether the domain is in the cluster - return 'result: exists=1' - if action == 'delete': - # Domain is being removed from the DNS - hostname = request.values.get('hostname') - username = request.values.get('username') - domain = request.values.get('select0') - - - if action == 'rawsave': - # DirectAdmin wants to add/update a domain - hostname = request.values.get('hostname') - username = request.values.get('username') - domain = request.values.get('domain') - - if not check_zone_exists(str(domain)): - put_zone_index(str(domain)) - write_zone_file(str(domain), request.data.decode("utf-8")) - else: - # Domain already exists - write_zone_file(str(domain), request.data.decode("utf-8")) - - -def create_zone_index(): - # Create an index of all zones present from zone definitions - regex = r"(?<=\")(?P.*)(?=\"\s)" - - with open(zone_index_file, 'w+') as f: - with open(named_conf, 'r') as named_file: - while True: - # read line - line = named_file.readline() - if not line: - # Reached end of file - break - print(line) - hosted_domain = re.search(regex, line).group(0) - f.write(hosted_domain + "\n") - - -def put_zone_index(zone_name): - # add a new zone to index - with open(zone_index_file, 'a+') as f: - # We are using append mode - f.write(zone_name) - - -def write_zone_file(zone_name, data): - # Write the zone to file - with open(zones_dir + '/' + zone_name + '.db', 'w') as f: - f.write(data) - - -def check_zone_exists(zone_name): - # Check if zone is present in the index - with open(zone_index_file, 'r') as f: - try: - s = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - if s.find(bytes(zone_name, encoding='utf8')) != -1: - return True - else: - return False - except ValueError as e: - # File Empty? - return False - - -if __name__ == '__main__': - zones_dir = "/etc/pdns/zones" - zone_index_file = "/etc/pdns/zones/.index" - named_conf = "/etc/pdns/named.conf" - create_zone_index() - - app.run(host="0.0.0.0") diff --git a/config.json b/config.json deleted file mode 100644 index 9e26dfe..0000000 --- a/config.json +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/config/app.yml b/config/app.yml deleted file mode 100644 index f581001..0000000 --- a/config/app.yml +++ /dev/null @@ -1,29 +0,0 @@ ---- -timezone: Pacific/Auckland -log_level: INFO -queue_location: ./data/queues - -dns: - # default_backend: coredns_mysql - backends: - bind_backend: - type: bind - enabled: false - zones_dir: /etc/named/zones/dadns - named_conf: /etc/bind/named.conf.local - coredns_primary: - enabled: true - host: "mysql" # Matches Docker service name - port: 3306 - database: "coredns" - username: "coredns" - password: "coredns123" - table_name: "records" - coredns_secondary: - enabled: false - host: "mysql" # Matches Docker service name - port: 3306 - database: "coredns" - username: "coredns" - password: "coredns123" - table_name: "records" \ No newline at end of file diff --git a/directdnsonly/app/backends/coredns_mysql.py b/directdnsonly/app/backends/coredns_mysql.py index e75c7f0..9d457db 100644 --- a/directdnsonly/app/backends/coredns_mysql.py +++ b/directdnsonly/app/backends/coredns_mysql.py @@ -7,7 +7,6 @@ from dns import zone as dns_zone_module from dns.rdataclass import IN from loguru import logger from .base import DNSBackend -from config import config Base = declarative_base() @@ -197,7 +196,7 @@ class CoreDNSMySQLBackend(DNSBackend): session = self.Session() try: # First find the zone - zone = session.query(Zone).filter_by(name=zone_name).first() + zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first() if not zone: logger.warning(f"Zone {zone_name} not found for deletion") return False @@ -231,7 +230,7 @@ class CoreDNSMySQLBackend(DNSBackend): session = self.Session() try: exists = ( - session.query(Zone).filter_by(name=self.dot_fqdn(zone_name)).first() + session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first() is not None ) logger.debug(f"Zone existence check for {zone_name}: {exists}") diff --git a/directdnsonly/app/reconciler.py b/directdnsonly/app/reconciler.py index a2a9b8f..baa693d 100755 --- a/directdnsonly/app/reconciler.py +++ b/directdnsonly/app/reconciler.py @@ -27,6 +27,8 @@ class ReconciliationWorker: self.interval_seconds = reconciliation_config.get("interval_minutes", 60) * 60 self.servers = reconciliation_config.get("directadmin_servers") or [] self.verify_ssl = reconciliation_config.get("verify_ssl", True) + self.ipp = int(reconciliation_config.get("ipp", 1000)) + self.dry_run = bool(reconciliation_config.get("dry_run", False)) self._stop_event = threading.Event() self._thread = None @@ -46,11 +48,16 @@ class ReconciliationWorker: ) self._thread.start() server_names = [s.get("hostname", "?") for s in self.servers] + mode = "DRY-RUN" if self.dry_run else "LIVE" logger.info( - f"Reconciliation poller started — " + f"Reconciliation poller started [{mode}] — " f"interval: {self.interval_seconds // 60}m, " f"servers: {server_names}" ) + if self.dry_run: + logger.warning( + "[reconciler] DRY-RUN mode active — orphans will be logged but NOT queued for deletion" + ) def stop(self): self._stop_event.set() @@ -93,6 +100,7 @@ class ReconciliationWorker: server.get("username"), server.get("password"), server.get("ssl", True), + ipp=self.ipp, ) if da_domains is not None: for d in da_domains: @@ -107,98 +115,68 @@ class ReconciliationWorker: # Now check local DB for all domains, update master if needed, and queue deletes only from recorded master session = connect() - all_local_domains = session.query(Domain).all() - migrated = 0 - for record in all_local_domains: - domain = record.domain - recorded_master = record.hostname - actual_master = all_da_domains.get(domain) - if actual_master: - if actual_master != recorded_master: - logger.warning( - f"[reconciler] Domain '{domain}' migrated: recorded master '{recorded_master}' -> new master '{actual_master}'. Updating local DB." - ) - record.hostname = actual_master - migrated += 1 - else: - # Only queue delete if this is the recorded master - if recorded_master in [s.get("hostname") for s in self.servers]: - self.delete_queue.put({ - "domain": record.domain, - "hostname": record.hostname, - "username": record.username or "", - "source": "reconciler", - }) - logger.debug( - f"[reconciler] Queued delete for orphan: {record.domain} (master: {recorded_master})" - ) - total_queued += 1 - if migrated: - session.commit() - logger.info(f"[reconciler] {migrated} domain(s) migrated to new master and updated in DB.") - logger.info( - f"[reconciler] Reconciliation pass complete — " - f"{total_queued} domain(s) queued for deletion" - ) - - def _reconcile_server(self, server: dict) -> int: - """Reconcile one DA server. Returns number of domains queued for delete.""" - hostname = server["hostname"] - port = server.get("port", 2222) - username = server.get("username") - password = server.get("password") - use_ssl = server.get("ssl", True) - - logger.info(f"[reconciler] Polling {hostname}:{port}") - - da_domains = self._fetch_da_domains( - hostname, port, username, password, use_ssl - ) - if da_domains is None: - # Fetch failed — never delete on uncertainty - return 0 - - logger.debug( - f"[reconciler] {hostname}: {len(da_domains)} active domain(s) in DA" - ) - - session = connect() - our_domains = session.query(Domain).filter_by(hostname=hostname).all() - - if not our_domains: - logger.debug( - f"[reconciler] {hostname}: no domains registered from this server" - ) - return 0 - - orphans = [d for d in our_domains if d.domain not in da_domains] - - if not orphans: + try: + all_local_domains = session.query(Domain).all() + migrated = 0 + backfilled = 0 + known_servers = {s.get("hostname") for s in self.servers} + for record in all_local_domains: + domain = record.domain + recorded_master = record.hostname + actual_master = all_da_domains.get(domain) + if actual_master: + if not recorded_master: + logger.info( + f"[reconciler] Domain '{domain}' hostname backfilled: '{actual_master}'" + ) + record.hostname = actual_master + backfilled += 1 + elif actual_master != recorded_master: + logger.warning( + f"[reconciler] Domain '{domain}' migrated: " + f"'{recorded_master}' -> '{actual_master}'. Updating local DB." + ) + record.hostname = actual_master + migrated += 1 + else: + # Only act if the recorded master is one we're polling + if recorded_master in known_servers: + if self.dry_run: + logger.warning( + f"[reconciler] [DRY-RUN] Would delete orphan: {record.domain} " + f"(master: {recorded_master})" + ) + else: + self.delete_queue.put({ + "domain": record.domain, + "hostname": record.hostname, + "username": record.username or "", + "source": "reconciler", + }) + logger.debug( + f"[reconciler] Queued delete for orphan: {record.domain} " + f"(master: {recorded_master})" + ) + total_queued += 1 + if migrated or backfilled: + session.commit() + if backfilled: + logger.info(f"[reconciler] {backfilled} domain(s) had missing hostname backfilled.") + if migrated: + logger.info(f"[reconciler] {migrated} domain(s) migrated to new master.") + finally: + session.close() + if self.dry_run: logger.info( - f"[reconciler] {hostname}: all {len(our_domains)} registered " - f"domain(s) confirmed active in DA" + f"[reconciler] Reconciliation pass complete [DRY-RUN] — " + f"{total_queued} orphan(s) identified (none deleted)" ) - return 0 - - logger.warning( - f"[reconciler] {hostname}: {len(orphans)} orphaned domain(s) " - f"no longer in DA — queuing for deletion: " - f"{[d.domain for d in orphans]}" - ) - - for record in orphans: - self.delete_queue.put({ - "domain": record.domain, - "hostname": record.hostname, - "username": record.username or "", - "source": "reconciler", - }) - logger.debug( - f"[reconciler] Queued delete for orphan: {record.domain}" + else: + logger.info( + f"[reconciler] Reconciliation pass complete — " + f"{total_queued} domain(s) queued for deletion" ) - return len(orphans) - def _fetch_da_domains( self, hostname: str, port: int, username: str, password: str, use_ssl: bool, ipp: int = 1000 ): @@ -265,7 +243,10 @@ class ReconciliationWorker: total_pages = int(info.get("total_pages", 1)) page += 1 continue - except Exception: + except Exception as e: + logger.error( + f"[reconciler] JSON decode failed for {hostname}:{port} page {page}: {e}\nRaw response: {response.text[:500]}" + ) # Fallback to legacy parser domains = self._parse_da_domain_list(response.text) all_domains.update(domains) diff --git a/directdnsonly/config/__init__.py b/directdnsonly/config/__init__.py index 53e0ea5..fd9d542 100644 --- a/directdnsonly/config/__init__.py +++ b/directdnsonly/config/__init__.py @@ -10,6 +10,8 @@ from typing import Any, Dict def load_config() -> Vyper: # Initialize Vyper v.set_config_name("app") # Looks for app.yaml/app.yml + # Bundled config colocated with this module (always present in the package) + v.add_config_path(str(Path(__file__).parent)) v.add_config_path(".") # Search in current directory v.add_config_path("./config") v.set_env_prefix("DADNS") @@ -54,11 +56,15 @@ def load_config() -> Vyper: # Reconciliation poller defaults v.set_default("reconciliation.enabled", False) + v.set_default("reconciliation.dry_run", False) v.set_default("reconciliation.interval_minutes", 60) v.set_default("reconciliation.verify_ssl", True) # Read configuration - if not v.read_in_config(): + try: + if not v.read_in_config(): + logger.warning("No config file found, using defaults") + except Exception: logger.warning("No config file found, using defaults") return v diff --git a/directdnsonly/config/app.yml b/directdnsonly/config/app.yml index 6064bc2..6913b81 100644 --- a/directdnsonly/config/app.yml +++ b/directdnsonly/config/app.yml @@ -12,8 +12,10 @@ app: # If a DA server is unreachable, that server is skipped entirely. #reconciliation: # enabled: true +# dry_run: true # log orphans but do NOT queue deletes — safe first-run mode # interval_minutes: 60 # verify_ssl: true # set false for self-signed DA certs +# ipp: 1000 # items per page when polling DA (default 1000) # directadmin_servers: # - hostname: da1.example.com # port: 2222 diff --git a/directdnsonly/worker.py b/directdnsonly/worker.py index 678c2cf..b738b9d 100644 --- a/directdnsonly/worker.py +++ b/directdnsonly/worker.py @@ -184,29 +184,60 @@ class WorkerManager: f"skipping ownership check, proceeding with delete" ) - session.delete(record) - session.commit() - logger.info(f"Removed {domain} from database") - - remaining_domains = [d.domain for d in session.query(Domain).all()] - backends = self.backend_registry.get_available_backends() + remaining_domains = [d.domain for d in session.query(Domain).all()] + delete_success = True if not backends: logger.warning( - f"No active backends — {domain} removed from DB only" + f"No active backends — {domain} will be removed from DB only" ) elif len(backends) > 1: - self._process_backends_delete_parallel( - backends, domain, remaining_domains - ) + # Parallel delete, track failures + results = [] + def delete_backend_wrapper(backend_name, backend, domain, remaining_domains): + try: + return backend.delete_zone(domain) + except Exception as e: + logger.error(f"Error deleting {domain} from {backend_name}: {e}") + return False + from concurrent.futures import ThreadPoolExecutor, as_completed + with ThreadPoolExecutor(max_workers=len(backends)) as executor: + futures = { + executor.submit(delete_backend_wrapper, backend_name, backend, domain, remaining_domains): backend_name + for backend_name, backend in backends.items() + } + for future in as_completed(futures): + backend_name = futures[future] + try: + result = future.result() + results.append(result) + if not result: + logger.error(f"Failed to delete {domain} from {backend_name}") + except Exception as e: + logger.error(f"Unhandled error deleting from {backend_name}: {e}") + results.append(False) + delete_success = all(results) else: + # Single backend for backend_name, backend in backends.items(): - self._delete_single_backend( - backend_name, backend, domain, remaining_domains - ) + try: + result = backend.delete_zone(domain) + if not result: + logger.error(f"Failed to delete {domain} from {backend_name}") + delete_success = False + except Exception as e: + logger.error(f"Error deleting {domain} from {backend_name}: {e}") + delete_success = False - self.delete_queue.task_done() - logger.success(f"Delete completed for {domain}") + if delete_success: + session.delete(record) + session.commit() + logger.info(f"Removed {domain} from database") + self.delete_queue.task_done() + logger.success(f"Delete completed for {domain}") + else: + logger.error(f"Delete failed for {domain} on one or more backends — DB record retained") + self.delete_queue.task_done() except Empty: continue diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..421d3aa --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +"""Shared test fixtures for directdnsonly test suite.""" +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from directdnsonly.app.db import Base +from directdnsonly.app.db.models import Domain, Key # noqa: F401 — registers models with Base + + +@pytest.fixture +def engine(): + eng = create_engine("sqlite:///:memory:") + Base.metadata.create_all(eng) + yield eng + eng.dispose() + + +@pytest.fixture +def db_session(engine): + session = sessionmaker(bind=engine)() + yield session + session.close() + + +@pytest.fixture +def patch_connect(db_session, monkeypatch): + """Patch connect() at every call-site, returning the shared test session. + + Modules that import connect() directly (e.g. utils, reconciler) are + patched at their local name so the in-memory SQLite session is used + instead of trying to read from vyper config. + """ + _factory = lambda: db_session # noqa: E731 + monkeypatch.setattr("directdnsonly.app.utils.connect", _factory) + monkeypatch.setattr("directdnsonly.app.reconciler.connect", _factory) + return db_session diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py new file mode 100644 index 0000000..8820eef --- /dev/null +++ b/tests/test_admin_api.py @@ -0,0 +1,188 @@ +"""Tests for directdnsonly.app.api.admin — DNSAdminAPI handler methods.""" +import pytest +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs + +import cherrypy + +from directdnsonly.app.api.admin import DNSAdminAPI + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def save_queue(): + return MagicMock() + + +@pytest.fixture +def delete_queue(): + return MagicMock() + + +@pytest.fixture +def api(save_queue, delete_queue): + return DNSAdminAPI(save_queue, delete_queue, backend_registry=MagicMock()) + + +# --------------------------------------------------------------------------- +# CMD_API_LOGIN_TEST +# --------------------------------------------------------------------------- + + +def test_login_test_returns_success(api): + result = api.CMD_API_LOGIN_TEST() + parsed = parse_qs(result) + assert parsed["error"] == ["0"] + assert parsed["text"] == ["Login OK"] + + +# --------------------------------------------------------------------------- +# _handle_exists — GET action=exists +# --------------------------------------------------------------------------- + + +def test_handle_exists_missing_domain_returns_error(api): + with patch.object(cherrypy, "response", MagicMock()): + result = api._handle_exists({"action": "exists"}) + parsed = parse_qs(result) + assert parsed["error"] == ["1"] + + +def test_handle_exists_unsupported_action_returns_error(api): + with patch.object(cherrypy, "response", MagicMock()): + result = api._handle_exists({"action": "rawsave"}) + parsed = parse_qs(result) + assert parsed["error"] == ["1"] + + +def test_handle_exists_domain_not_found(api): + with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False), \ + patch("directdnsonly.app.api.admin.check_parent_domain_owner", return_value=False): + result = api._handle_exists({"action": "exists", "domain": "example.com"}) + + parsed = parse_qs(result) + assert parsed["error"] == ["0"] + assert parsed["exists"] == ["0"] + + +def test_handle_exists_domain_found(api): + record = MagicMock() + record.hostname = "da1.example.com" + + with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True), \ + patch("directdnsonly.app.api.admin.get_domain_record", return_value=record): + result = api._handle_exists({"action": "exists", "domain": "example.com"}) + + parsed = parse_qs(result) + assert parsed["error"] == ["0"] + assert parsed["exists"] == ["1"] + assert "da1.example.com" in parsed["details"][0] + + +def test_handle_exists_parent_found(api): + parent = MagicMock() + parent.hostname = "da2.example.com" + + with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False), \ + patch("directdnsonly.app.api.admin.check_parent_domain_owner", return_value=True), \ + patch("directdnsonly.app.api.admin.get_parent_domain_record", return_value=parent): + result = api._handle_exists({ + "action": "exists", + "domain": "sub.example.com", + "check_for_parent_domain": "1", + }) + + parsed = parse_qs(result) + assert parsed["error"] == ["0"] + assert parsed["exists"] == ["2"] + assert "da2.example.com" in parsed["details"][0] + + +def test_handle_exists_no_parent_check_when_flag_absent(api): + """check_parent_domain_owner should not be called if flag not set.""" + record = MagicMock() + record.hostname = "da1.example.com" + + with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True), \ + patch("directdnsonly.app.api.admin.check_parent_domain_owner") as mock_parent, \ + patch("directdnsonly.app.api.admin.get_domain_record", return_value=record): + api._handle_exists({"action": "exists", "domain": "example.com"}) + + mock_parent.assert_not_called() + + +# --------------------------------------------------------------------------- +# _handle_rawsave +# --------------------------------------------------------------------------- + +SAMPLE_ZONE = "$ORIGIN example.com.\n$TTL 300\nexample.com. 300 IN A 1.2.3.4\n" + + +def test_rawsave_enqueues_item(api, save_queue): + with patch("directdnsonly.app.api.admin.validate_and_normalize_zone", + return_value=SAMPLE_ZONE), \ + patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))): + result = api._handle_rawsave("example.com", { + "zone_file": SAMPLE_ZONE, + "hostname": "da1.example.com", + "username": "admin", + }) + + save_queue.put.assert_called_once() + item = save_queue.put.call_args[0][0] + assert item["domain"] == "example.com" + assert item["hostname"] == "da1.example.com" + assert item["username"] == "admin" + assert item["client_ip"] == "127.0.0.1" + + parsed = parse_qs(result) + assert parsed["error"] == ["0"] + + +def test_rawsave_missing_zone_file_raises(api): + with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))): + with pytest.raises(ValueError, match="Missing zone file"): + api._handle_rawsave("example.com", {}) + + +def test_rawsave_invalid_zone_raises(api): + with patch("directdnsonly.app.api.admin.validate_and_normalize_zone", + side_effect=ValueError("Invalid zone data: bad record")), \ + patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))): + with pytest.raises(ValueError, match="Invalid zone data"): + api._handle_rawsave("example.com", {"zone_file": "garbage"}) + + +# --------------------------------------------------------------------------- +# _handle_delete +# --------------------------------------------------------------------------- + + +def test_delete_enqueues_item(api, delete_queue): + with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="10.0.0.1"))): + result = api._handle_delete("example.com", { + "hostname": "da1.example.com", + "username": "admin", + }) + + delete_queue.put.assert_called_once() + item = delete_queue.put.call_args[0][0] + assert item["domain"] == "example.com" + assert item["hostname"] == "da1.example.com" + assert item["client_ip"] == "10.0.0.1" + + parsed = parse_qs(result) + assert parsed["error"] == ["0"] + + +def test_delete_missing_params_uses_empty_strings(api, delete_queue): + with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))): + api._handle_delete("example.com", {}) + + item = delete_queue.put.call_args[0][0] + assert item["hostname"] == "" + assert item["username"] == "" diff --git a/tests/test_coredns_mysql.py b/tests/test_coredns_mysql.py index c0d9b2c..cb38088 100644 --- a/tests/test_coredns_mysql.py +++ b/tests/test_coredns_mysql.py @@ -1,47 +1,158 @@ +"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite).""" import pytest from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker -from directdnsonly.app.backends.coredns_mysql import CoreDNSMySQLBackend, CoreDNSRecord -from loguru import logger +from directdnsonly.app.backends.coredns_mysql import ( + Base, + CoreDNSMySQLBackend, + Record, + Zone, +) + + +# --------------------------------------------------------------------------- +# Fixture — in-memory SQLite backend (bypasses real MySQL connection) +# --------------------------------------------------------------------------- @pytest.fixture -def mysql_backend(tmp_path): - # Setup in-memory SQLite for testing (replace with test MySQL in CI) +def mysql_backend(): engine = create_engine("sqlite:///:memory:") - CoreDNSRecord.metadata.create_all(engine) + Base.metadata.create_all(engine) - class TestBackend(CoreDNSMySQLBackend): + class _TestBackend(CoreDNSMySQLBackend): def __init__(self): - super().__init__() + # Manually initialise without triggering the MySQL create_engine call + self.config = {} + self.instance_name = "test" self.engine = engine self.Session = scoped_session(sessionmaker(bind=engine)) - yield TestBackend() + yield _TestBackend() engine.dispose() -def test_zone_operations(mysql_backend): - zone_data = """ +# --------------------------------------------------------------------------- +# write_zone / zone_exists / delete_zone +# --------------------------------------------------------------------------- + + +ZONE_DATA = """\ +$ORIGIN example.com. +$TTL 300 example.com. 300 IN SOA ns.example.com. admin.example.com. (2023 3600 1800 604800 86400) example.com. 300 IN A 192.0.2.1 """ - # Test zone creation - assert mysql_backend.write_zone("example.com", zone_data) + + +def test_write_zone_creates_zone(mysql_backend): + assert mysql_backend.write_zone("example.com", ZONE_DATA) + + +def test_zone_exists_after_write(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) assert mysql_backend.zone_exists("example.com") - # Test record update - updated_zone = """ + +def test_zone_does_not_exist_before_write(mysql_backend): + assert not mysql_backend.zone_exists("missing.com") + + +def test_write_zone_idempotent(mysql_backend): + assert mysql_backend.write_zone("example.com", ZONE_DATA) + assert mysql_backend.write_zone("example.com", ZONE_DATA) + + +def test_write_zone_updates_records(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) + + updated = """\ +$ORIGIN example.com. +$TTL 300 example.com. 3600 IN A 192.0.2.1 example.com. 300 IN AAAA 2001:db8::1 """ - assert mysql_backend.write_zone("example.com", updated_zone) + assert mysql_backend.write_zone("example.com", updated) - # Test record removal - reduced_zone = "example.com. 300 IN A 192.0.2.1" - assert mysql_backend.write_zone("example.com", reduced_zone) - # Test zone deletion +def test_write_zone_removes_stale_records(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) + + reduced = "example.com. 300 IN A 192.0.2.1" + mysql_backend.write_zone("example.com", reduced) + + session = mysql_backend.Session() + zone = session.query(Zone).filter_by(zone_name="example.com.").first() + records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all() + assert records == [] + session.close() + + +def test_delete_zone_removes_zone_and_records(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) assert mysql_backend.delete_zone("example.com") - assert not mysql_backend.zone_exists("example.com") \ No newline at end of file + assert not mysql_backend.zone_exists("example.com") + + +def test_delete_nonexistent_zone_returns_false(mysql_backend): + assert not mysql_backend.delete_zone("ghost.com") + + +def test_reload_zone_returns_true(mysql_backend): + assert mysql_backend.reload_zone("example.com") + assert mysql_backend.reload_zone() + + +# --------------------------------------------------------------------------- +# verify_zone_record_count +# --------------------------------------------------------------------------- + + +def test_verify_zone_record_count_match(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) + # SOA + A = 2 records total + matches, count = mysql_backend.verify_zone_record_count("example.com", 2) + assert matches + assert count == 2 + + +def test_verify_zone_record_count_mismatch(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) + matches, count = mysql_backend.verify_zone_record_count("example.com", 99) + assert not matches + assert count == 2 + + +def test_verify_zone_record_count_missing_zone(mysql_backend): + matches, count = mysql_backend.verify_zone_record_count("ghost.com", 0) + assert not matches + assert count == 0 + + +# --------------------------------------------------------------------------- +# reconcile_zone_records +# --------------------------------------------------------------------------- + + +def test_reconcile_removes_extra_records(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) + + # Inject a phantom record directly into the DB + session = mysql_backend.Session() + zone = session.query(Zone).filter_by(zone_name="example.com.").first() + session.add(Record(zone_id=zone.id, hostname="phantom", type="A", + data="10.0.0.99", ttl=300, online=True)) + session.commit() + session.close() + + success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA) + assert success + assert removed == 1 + + +def test_reconcile_no_changes_when_zone_matches(mysql_backend): + mysql_backend.write_zone("example.com", ZONE_DATA) + success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA) + assert success + assert removed == 0 diff --git a/tests/test_reconciler.py b/tests/test_reconciler.py new file mode 100644 index 0000000..14b5f91 --- /dev/null +++ b/tests/test_reconciler.py @@ -0,0 +1,262 @@ +"""Tests for directdnsonly.app.reconciler — ReconciliationWorker.""" +import pytest +import requests.exceptions +from queue import Queue +from unittest.mock import MagicMock, patch + +from directdnsonly.app.reconciler import ReconciliationWorker +from directdnsonly.app.db.models import Domain + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +SERVER = {"hostname": "da1.example.com", "port": 2222, "username": "admin", "password": "secret", "ssl": True} + +BASE_CONFIG = { + "enabled": True, + "dry_run": False, + "interval_minutes": 60, + "verify_ssl": True, + "ipp": 100, + "directadmin_servers": [SERVER], +} + + +@pytest.fixture +def delete_queue(): + return Queue() + + +@pytest.fixture +def worker(delete_queue): + return ReconciliationWorker(delete_queue, BASE_CONFIG) + + +@pytest.fixture +def dry_run_worker(delete_queue): + cfg = {**BASE_CONFIG, "dry_run": True} + return ReconciliationWorker(delete_queue, cfg) + + +# --------------------------------------------------------------------------- +# _reconcile_all — orphan detection +# --------------------------------------------------------------------------- + + +def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_connect): + patch_connect.add(Domain(domain="orphan.com", hostname="da1.example.com", username="admin")) + patch_connect.commit() + + with patch.object(worker, "_fetch_da_domains", return_value=set()): + worker._reconcile_all() + + assert not delete_queue.empty() + item = delete_queue.get_nowait() + assert item["domain"] == "orphan.com" + assert item["source"] == "reconciler" + + +def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connect): + patch_connect.add(Domain(domain="orphan.com", hostname="da1.example.com", username="admin")) + patch_connect.commit() + + with patch.object(dry_run_worker, "_fetch_da_domains", return_value=set()): + dry_run_worker._reconcile_all() + + assert delete_queue.empty() + + +def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connect): + """Domains whose recorded master is NOT in our configured servers are skipped.""" + patch_connect.add(Domain(domain="other.com", hostname="da99.unknown.com", username="admin")) + patch_connect.commit() + + with patch.object(worker, "_fetch_da_domains", return_value=set()): + worker._reconcile_all() + + assert delete_queue.empty() + + +def test_active_domain_not_queued(worker, delete_queue, patch_connect): + patch_connect.add(Domain(domain="good.com", hostname="da1.example.com", username="admin")) + patch_connect.commit() + + with patch.object(worker, "_fetch_da_domains", return_value={"good.com"}): + worker._reconcile_all() + + assert delete_queue.empty() + + +# --------------------------------------------------------------------------- +# _reconcile_all — hostname backfill and migration +# --------------------------------------------------------------------------- + + +def test_backfill_null_hostname(worker, patch_connect): + patch_connect.add(Domain(domain="backfill.com", hostname=None, username="admin")) + patch_connect.commit() + + with patch.object(worker, "_fetch_da_domains", return_value={"backfill.com"}): + worker._reconcile_all() + + record = patch_connect.query(Domain).filter_by(domain="backfill.com").first() + assert record.hostname == "da1.example.com" + + +def test_migration_updates_hostname(worker, patch_connect): + patch_connect.add(Domain(domain="moved.com", hostname="da-old.example.com", username="admin")) + patch_connect.commit() + + with patch.object(worker, "_fetch_da_domains", return_value={"moved.com"}): + worker._reconcile_all() + + record = patch_connect.query(Domain).filter_by(domain="moved.com").first() + assert record.hostname == "da1.example.com" + + +def test_dry_run_still_backfills(dry_run_worker, patch_connect): + """Backfill is a data-repair operation, applied even in dry-run mode.""" + patch_connect.add(Domain(domain="fill.com", hostname=None, username="admin")) + patch_connect.commit() + + with patch.object(dry_run_worker, "_fetch_da_domains", return_value={"fill.com"}): + dry_run_worker._reconcile_all() + + record = patch_connect.query(Domain).filter_by(domain="fill.com").first() + assert record.hostname == "da1.example.com" + + +# --------------------------------------------------------------------------- +# _fetch_da_domains — HTTP handling +# --------------------------------------------------------------------------- + + +def _make_json_response(domains_dict, total_pages=1): + """Return a mock requests.Response with JSON payload matching DA format.""" + data = {str(i): {"domain": d} for i, d in enumerate(domains_dict)} + data["info"] = {"total_pages": total_pages} + mock = MagicMock() + mock.status_code = 200 + mock.is_redirect = False + mock.headers = {"Content-Type": "application/json"} + mock.json.return_value = data + mock.raise_for_status = MagicMock() + return mock + + +def test_fetch_returns_domains_from_json(worker): + mock_resp = _make_json_response(["example.com", "test.com"]) + + with patch("requests.get", return_value=mock_resp): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result == {"example.com", "test.com"} + + +def test_fetch_paginates(worker): + page1 = _make_json_response(["a.com"], total_pages=2) + page2 = _make_json_response(["b.com"], total_pages=2) + + with patch("requests.get", side_effect=[page1, page2]): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result == {"a.com", "b.com"} + + +def test_fetch_redirect_triggers_session_login(worker): + redirect_resp = MagicMock() + redirect_resp.status_code = 302 + redirect_resp.is_redirect = True + + with patch("requests.get", return_value=redirect_resp), \ + patch.object(worker, "_da_session_login", return_value=None): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result is None + + +def test_fetch_html_response_returns_none(worker): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.is_redirect = False + mock_resp.headers = {"Content-Type": "text/html; charset=utf-8"} + mock_resp.raise_for_status = MagicMock() + + with patch("requests.get", return_value=mock_resp): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result is None + + +def test_fetch_connection_error_returns_none(worker): + with patch("requests.get", side_effect=requests.exceptions.ConnectionError("refused")): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result is None + + +def test_fetch_timeout_returns_none(worker): + with patch("requests.get", side_effect=requests.exceptions.Timeout()): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result is None + + +def test_fetch_ssl_error_returns_none(worker): + with patch("requests.get", side_effect=requests.exceptions.SSLError("cert verify failed")): + result = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) + + assert result is None + + +# --------------------------------------------------------------------------- +# _parse_da_domain_list — legacy format fallback +# --------------------------------------------------------------------------- + + +def test_parse_standard_querystring(): + body = "list[]=example.com&list[]=test.com" + result = ReconciliationWorker._parse_da_domain_list(body) + assert result == {"example.com", "test.com"} + + +def test_parse_newline_separated(): + body = "list[]=example.com\nlist[]=test.com" + result = ReconciliationWorker._parse_da_domain_list(body) + assert result == {"example.com", "test.com"} + + +def test_parse_empty_body_returns_empty_set(): + assert ReconciliationWorker._parse_da_domain_list("") == set() + + +def test_parse_normalises_to_lowercase(): + result = ReconciliationWorker._parse_da_domain_list("list[]=EXAMPLE.COM") + assert "example.com" in result + assert "EXAMPLE.COM" not in result + + +def test_parse_strips_whitespace(): + result = ReconciliationWorker._parse_da_domain_list("list[]= example.com ") + assert "example.com" in result + + +# --------------------------------------------------------------------------- +# Worker lifecycle +# --------------------------------------------------------------------------- + + +def test_disabled_worker_does_not_start(delete_queue): + cfg = {**BASE_CONFIG, "enabled": False} + w = ReconciliationWorker(delete_queue, cfg) + w.start() + assert not w.is_alive + + +def test_no_servers_does_not_start(delete_queue): + cfg = {**BASE_CONFIG, "directadmin_servers": []} + w = ReconciliationWorker(delete_queue, cfg) + w.start() + assert not w.is_alive diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..756ca67 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,137 @@ +"""Tests for directdnsonly.app.utils — zone index helper functions.""" +import pytest + +from directdnsonly.app.db.models import Domain +from directdnsonly.app.utils import ( + check_zone_exists, + check_parent_domain_owner, + get_domain_record, + get_parent_domain_record, + put_zone_index, +) + + +# --------------------------------------------------------------------------- +# check_zone_exists +# --------------------------------------------------------------------------- + + +def test_check_zone_exists_not_found(patch_connect): + assert check_zone_exists("example.com") is False + + +def test_check_zone_exists_found(patch_connect): + patch_connect.add( + Domain(domain="example.com", hostname="da1.example.com", username="admin") + ) + patch_connect.commit() + + assert check_zone_exists("example.com") is True + + +def test_check_zone_exists_does_not_match_partial(patch_connect): + patch_connect.add( + Domain(domain="example.com", hostname="da1.example.com", username="admin") + ) + patch_connect.commit() + + assert check_zone_exists("sub.example.com") is False + + +# --------------------------------------------------------------------------- +# put_zone_index +# --------------------------------------------------------------------------- + + +def test_put_zone_index_adds_record(patch_connect): + put_zone_index("new.com", "da1.example.com", "admin") + + record = patch_connect.query(Domain).filter_by(domain="new.com").first() + assert record is not None + assert record.hostname == "da1.example.com" + assert record.username == "admin" + + +def test_put_zone_index_stores_domain_name(patch_connect): + put_zone_index("another.nz", "da2.example.com", "user1") + + assert check_zone_exists("another.nz") is True + + +# --------------------------------------------------------------------------- +# get_domain_record +# --------------------------------------------------------------------------- + + +def test_get_domain_record_returns_none_when_missing(patch_connect): + assert get_domain_record("missing.com") is None + + +def test_get_domain_record_returns_record(patch_connect): + patch_connect.add( + Domain(domain="found.com", hostname="da1.example.com", username="admin") + ) + patch_connect.commit() + + record = get_domain_record("found.com") + assert record is not None + assert record.domain == "found.com" + assert record.hostname == "da1.example.com" + + +# --------------------------------------------------------------------------- +# check_parent_domain_owner +# --------------------------------------------------------------------------- + + +def test_check_parent_domain_owner_not_found(patch_connect): + assert check_parent_domain_owner("sub.example.com") is False + + +def test_check_parent_domain_owner_found(patch_connect): + patch_connect.add( + Domain(domain="example.com", hostname="da1.example.com", username="admin") + ) + patch_connect.commit() + + assert check_parent_domain_owner("sub.example.com") is True + + +def test_check_parent_domain_owner_single_label_returns_false(patch_connect): + # A single-label name like "com" has no parent + assert check_parent_domain_owner("com") is False + + +def test_check_parent_domain_owner_ignores_grandparent(patch_connect): + # Only the immediate parent is checked, not grandparents + patch_connect.add( + Domain(domain="example.com", hostname="da1.example.com", username="admin") + ) + patch_connect.commit() + + # deep.sub.example.com's immediate parent is sub.example.com (not in DB) + assert check_parent_domain_owner("deep.sub.example.com") is False + + +# --------------------------------------------------------------------------- +# get_parent_domain_record +# --------------------------------------------------------------------------- + + +def test_get_parent_domain_record_returns_none_when_missing(patch_connect): + assert get_parent_domain_record("sub.example.com") is None + + +def test_get_parent_domain_record_returns_parent(patch_connect): + patch_connect.add( + Domain(domain="example.com", hostname="da1.example.com", username="admin") + ) + patch_connect.commit() + + parent = get_parent_domain_record("sub.example.com") + assert parent is not None + assert parent.domain == "example.com" + + +def test_get_parent_domain_record_single_label_returns_none(patch_connect): + assert get_parent_domain_record("com") is None diff --git a/tests/test_zone_parser.py b/tests/test_zone_parser.py new file mode 100644 index 0000000..1a745d2 --- /dev/null +++ b/tests/test_zone_parser.py @@ -0,0 +1,100 @@ +"""Tests for directdnsonly.app.utils.zone_parser.""" +import pytest +from dns.exception import DNSException + +from directdnsonly.app.utils.zone_parser import ( + count_zone_records, + validate_and_normalize_zone, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +MINIMAL_ZONE = "example.com. 300 IN A 1.2.3.4" + +FULL_ZONE = """\ +$ORIGIN example.com. +$TTL 300 +@ IN SOA ns1.example.com. admin.example.com. 2024010101 3600 900 604800 300 +@ IN NS ns1.example.com. +@ IN A 1.2.3.4 +www IN A 5.6.7.8 +mail IN MX 10 mail.example.com. +""" + + +# --------------------------------------------------------------------------- +# validate_and_normalize_zone +# --------------------------------------------------------------------------- + + +def test_validate_adds_origin_when_missing(): + result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com") + assert "$ORIGIN example.com." in result + + +def test_validate_adds_ttl_when_missing(): + result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com") + assert "$TTL" in result + + +def test_validate_does_not_duplicate_origin(): + zone = "$ORIGIN example.com.\nexample.com. 300 IN A 1.2.3.4" + result = validate_and_normalize_zone(zone, "example.com") + assert result.count("$ORIGIN") == 1 + + +def test_validate_does_not_duplicate_ttl(): + zone = "$TTL 300\nexample.com. 300 IN A 1.2.3.4" + result = validate_and_normalize_zone(zone, "example.com") + assert result.count("$TTL") == 1 + + +def test_validate_appends_dot_to_domain(): + result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com") + assert "$ORIGIN example.com." in result + + +def test_validate_returns_string(): + result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com") + assert isinstance(result, str) + + +def test_validate_full_zone_passes(): + result = validate_and_normalize_zone(FULL_ZONE, "example.com") + assert result is not None + + +def test_validate_raises_on_invalid_zone(): + bad_zone = "this is not a zone file at all !!!" + with pytest.raises(ValueError, match="Invalid zone data"): + validate_and_normalize_zone(bad_zone, "example.com") + + +# --------------------------------------------------------------------------- +# count_zone_records +# --------------------------------------------------------------------------- + + +def test_count_records_simple_zone(): + zone = "$ORIGIN example.com.\n$TTL 300\n@ IN A 1.2.3.4\n@ IN AAAA ::1\n" + count = count_zone_records(zone, "example.com") + assert count == 2 + + +def test_count_records_soa_included(): + count = count_zone_records(FULL_ZONE, "example.com") + # SOA + NS + A (apex) + A (www) + MX = 5 + assert count == 5 + + +def test_count_records_returns_negative_on_bad_zone(): + count = count_zone_records("not a valid zone", "example.com") + assert count == -1 + + +def test_count_records_empty_zone(): + zone = "$ORIGIN example.com.\n$TTL 300\n" + count = count_zone_records(zone, "example.com") + assert count == 0