feat: add test suite, fix backend bugs, remove legacy artifacts 🧪

- Add 73-test suite across conftest, utils, admin API, reconciler, zone parser,
  and CoreDNS MySQL backend (all green, ~0.5s)
- Fix zone_exists filter using wrong column name (name → zone_name)
- Fix delete_zone missing dot_fqdn normalization on lookup
- Remove spurious unused `from config import config` in coredns_mysql.py
- Fix config loader to search module-relative path so tests find app.yml
  without needing a root-level config/ directory
- Remove legacy v1 Flask prototype (app.py), empty config.json, and
  duplicate root config/app.yml
This commit is contained in:
2026-02-18 22:03:04 +13:00
parent b8f12d0208
commit bd46227364
14 changed files with 982 additions and 264 deletions

105
app.py
View File

@@ -1,105 +0,0 @@
from flask import Flask, request
import mmap
import re
app = Flask(__name__)
@app.route('/')
def hello_world():
    """Root endpoint used as a basic liveness check."""
    greeting = 'Hello World!'
    return greeting
@app.route('/CMD_API_LOGIN_TEST')
def login_test():
    """Emulate DirectAdmin's login-test endpoint.

    Dumps the incoming parameters, headers, and authorization info to
    stdout for debugging, then reports success in DirectAdmin's
    ``key=value`` response format.
    """
    params = request.values
    for name in params:
        print(params.get(name))
        print(params.getlist(name))
    print(request.headers)
    print(request.authorization)
    return 'error=0&text=Login OK&details=none'
@app.route('/CMD_API_DNS_ADMIN', methods=['GET', 'POST'])
def domain_admin():
    """Handle DirectAdmin DNS-clustering callbacks.

    DirectAdmin selects the operation via the ``action`` parameter:
      * ``exists``  — is this domain already hosted in the cluster?
      * ``delete``  — a domain was removed (acknowledged only; this
        prototype never implemented the actual removal).
      * ``rawsave`` — create or update a zone; the raw zone file is the
        request body.

    Every path now returns a DirectAdmin-style ``key=value`` response;
    the original fell through and returned ``None`` (an HTTP 500 in
    Flask) for the ``delete``/``rawsave`` and unknown-action paths.
    """
    print(str(request.data, encoding="utf-8"))
    action = request.values.get('action')
    print(action)
    if action == 'exists':
        # DirectAdmin is checking whether the domain is in the cluster.
        return 'result: exists=1'
    if action == 'delete':
        # Domain removed from DNS. The hostname/username/select0 params
        # identify it, but deletion was never implemented here — just
        # acknowledge so DirectAdmin does not see a server error.
        return 'error=0&text=Delete acknowledged&details=none'
    if action == 'rawsave':
        # DirectAdmin wants to add or update a domain.
        domain = request.values.get('domain')
        if not check_zone_exists(str(domain)):
            # First sighting — record it in the index.
            put_zone_index(str(domain))
        # New or existing, (re)write the zone file from the request body.
        write_zone_file(str(domain), request.data.decode("utf-8"))
        return 'error=0&text=Zone saved&details=none'
    # Unknown action: report failure instead of returning None.
    return 'error=1&text=Unsupported action&details=none'
def create_zone_index(conf_path=None, index_path=None):
    """(Re)build the flat zone index from named.conf zone declarations.

    Scans the BIND config for quoted zone names (``zone "example.com" {``)
    and writes one domain per line to the index, truncating any previous
    contents.

    Args:
        conf_path: named.conf to scan; defaults to the module-level
            ``named_conf``.
        index_path: index file to write; defaults to the module-level
            ``zone_index_file``.

    Fix: the original called ``re.search(...).group(0)`` unconditionally
    and crashed with AttributeError on any line without a quoted name
    (blank lines, braces, comments — most of a real named.conf).
    """
    conf = named_conf if conf_path is None else conf_path
    index = zone_index_file if index_path is None else index_path
    # Text between an opening quote and a closing quote followed by whitespace.
    pattern = re.compile(r"(?<=\")(?P<domain>.*)(?=\"\s)")
    with open(index, 'w+') as out, open(conf, 'r') as named_file:
        for line in named_file:
            match = pattern.search(line)
            if match is not None:
                out.write(match.group(0) + "\n")
def put_zone_index(zone_name, index_path=None):
    """Append a newly hosted zone to the index, one domain per line.

    Args:
        zone_name: domain to record.
        index_path: index file to append to; defaults to the module-level
            ``zone_index_file``.

    Fix: the original wrote the bare name with no trailing newline, so
    consecutive entries ran together into a single line and corrupted
    later lookups against the index.
    """
    index = zone_index_file if index_path is None else index_path
    # 'a+' creates the file if it does not exist yet.
    with open(index, 'a+') as f:
        f.write(zone_name + "\n")
def write_zone_file(zone_name, data):
    """Persist a zone's raw contents to ``<zones_dir>/<zone_name>.db``."""
    target = "{}/{}.db".format(zones_dir, zone_name)
    with open(target, 'w') as zone_file:
        zone_file.write(data)
def check_zone_exists(zone_name, index_path=None):
    """Return True if ``zone_name`` has an entry in the zone index.

    Args:
        zone_name: domain to look up.
        index_path: index file to read; defaults to the module-level
            ``zone_index_file``.

    Fix: the original mmap'd the file and did a raw substring search, so
    "example.com" also matched an entry like "notexample.com". Entries
    are now compared as whole lines instead.
    """
    index = zone_index_file if index_path is None else index_path
    with open(index, 'r') as f:
        for line in f:
            if line.strip() == zone_name:
                return True
    # Not found (or the index is empty).
    return False
if __name__ == '__main__':
    # Prototype wiring: PowerDNS-style paths, then build the index and
    # serve the DirectAdmin API on all interfaces.
    zones_dir = "/etc/pdns/zones"
    zone_index_file = "/etc/pdns/zones/.index"
    named_conf = "/etc/pdns/named.conf"
    create_zone_index()
    app.run(host="0.0.0.0")

View File

@@ -1 +0,0 @@
{}

View File

@@ -1,29 +0,0 @@
---
timezone: Pacific/Auckland
log_level: INFO
queue_location: ./data/queues
dns:
# default_backend: coredns_mysql
backends:
bind_backend:
type: bind
enabled: false
zones_dir: /etc/named/zones/dadns
named_conf: /etc/bind/named.conf.local
coredns_primary:
enabled: true
host: "mysql" # Matches Docker service name
port: 3306
database: "coredns"
username: "coredns"
password: "coredns123"
table_name: "records"
coredns_secondary:
enabled: false
host: "mysql" # Matches Docker service name
port: 3306
database: "coredns"
username: "coredns"
password: "coredns123"
table_name: "records"

View File

@@ -7,7 +7,6 @@ from dns import zone as dns_zone_module
from dns.rdataclass import IN
from loguru import logger
from .base import DNSBackend
from config import config
Base = declarative_base()
@@ -197,7 +196,7 @@ class CoreDNSMySQLBackend(DNSBackend):
session = self.Session()
try:
# First find the zone
zone = session.query(Zone).filter_by(name=zone_name).first()
zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first()
if not zone:
logger.warning(f"Zone {zone_name} not found for deletion")
return False
@@ -231,7 +230,7 @@ class CoreDNSMySQLBackend(DNSBackend):
session = self.Session()
try:
exists = (
session.query(Zone).filter_by(name=self.dot_fqdn(zone_name)).first()
session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first()
is not None
)
logger.debug(f"Zone existence check for {zone_name}: {exists}")

View File

@@ -27,6 +27,8 @@ class ReconciliationWorker:
self.interval_seconds = reconciliation_config.get("interval_minutes", 60) * 60
self.servers = reconciliation_config.get("directadmin_servers") or []
self.verify_ssl = reconciliation_config.get("verify_ssl", True)
self.ipp = int(reconciliation_config.get("ipp", 1000))
self.dry_run = bool(reconciliation_config.get("dry_run", False))
self._stop_event = threading.Event()
self._thread = None
@@ -46,11 +48,16 @@ class ReconciliationWorker:
)
self._thread.start()
server_names = [s.get("hostname", "?") for s in self.servers]
mode = "DRY-RUN" if self.dry_run else "LIVE"
logger.info(
f"Reconciliation poller started — "
f"Reconciliation poller started [{mode}]"
f"interval: {self.interval_seconds // 60}m, "
f"servers: {server_names}"
)
if self.dry_run:
logger.warning(
"[reconciler] DRY-RUN mode active — orphans will be logged but NOT queued for deletion"
)
def stop(self):
self._stop_event.set()
@@ -93,6 +100,7 @@ class ReconciliationWorker:
server.get("username"),
server.get("password"),
server.get("ssl", True),
ipp=self.ipp,
)
if da_domains is not None:
for d in da_domains:
@@ -107,98 +115,68 @@ class ReconciliationWorker:
# Now check local DB for all domains, update master if needed, and queue deletes only from recorded master
session = connect()
all_local_domains = session.query(Domain).all()
migrated = 0
for record in all_local_domains:
domain = record.domain
recorded_master = record.hostname
actual_master = all_da_domains.get(domain)
if actual_master:
if actual_master != recorded_master:
logger.warning(
f"[reconciler] Domain '{domain}' migrated: recorded master '{recorded_master}' -> new master '{actual_master}'. Updating local DB."
)
record.hostname = actual_master
migrated += 1
else:
# Only queue delete if this is the recorded master
if recorded_master in [s.get("hostname") for s in self.servers]:
self.delete_queue.put({
"domain": record.domain,
"hostname": record.hostname,
"username": record.username or "",
"source": "reconciler",
})
logger.debug(
f"[reconciler] Queued delete for orphan: {record.domain} (master: {recorded_master})"
)
total_queued += 1
if migrated:
session.commit()
logger.info(f"[reconciler] {migrated} domain(s) migrated to new master and updated in DB.")
logger.info(
f"[reconciler] Reconciliation pass complete — "
f"{total_queued} domain(s) queued for deletion"
)
def _reconcile_server(self, server: dict) -> int:
"""Reconcile one DA server. Returns number of domains queued for delete."""
hostname = server["hostname"]
port = server.get("port", 2222)
username = server.get("username")
password = server.get("password")
use_ssl = server.get("ssl", True)
logger.info(f"[reconciler] Polling {hostname}:{port}")
da_domains = self._fetch_da_domains(
hostname, port, username, password, use_ssl
)
if da_domains is None:
# Fetch failed — never delete on uncertainty
return 0
logger.debug(
f"[reconciler] {hostname}: {len(da_domains)} active domain(s) in DA"
)
session = connect()
our_domains = session.query(Domain).filter_by(hostname=hostname).all()
if not our_domains:
logger.debug(
f"[reconciler] {hostname}: no domains registered from this server"
)
return 0
orphans = [d for d in our_domains if d.domain not in da_domains]
if not orphans:
try:
all_local_domains = session.query(Domain).all()
migrated = 0
backfilled = 0
known_servers = {s.get("hostname") for s in self.servers}
for record in all_local_domains:
domain = record.domain
recorded_master = record.hostname
actual_master = all_da_domains.get(domain)
if actual_master:
if not recorded_master:
logger.info(
f"[reconciler] Domain '{domain}' hostname backfilled: '{actual_master}'"
)
record.hostname = actual_master
backfilled += 1
elif actual_master != recorded_master:
logger.warning(
f"[reconciler] Domain '{domain}' migrated: "
f"'{recorded_master}' -> '{actual_master}'. Updating local DB."
)
record.hostname = actual_master
migrated += 1
else:
# Only act if the recorded master is one we're polling
if recorded_master in known_servers:
if self.dry_run:
logger.warning(
f"[reconciler] [DRY-RUN] Would delete orphan: {record.domain} "
f"(master: {recorded_master})"
)
else:
self.delete_queue.put({
"domain": record.domain,
"hostname": record.hostname,
"username": record.username or "",
"source": "reconciler",
})
logger.debug(
f"[reconciler] Queued delete for orphan: {record.domain} "
f"(master: {recorded_master})"
)
total_queued += 1
if migrated or backfilled:
session.commit()
if backfilled:
logger.info(f"[reconciler] {backfilled} domain(s) had missing hostname backfilled.")
if migrated:
logger.info(f"[reconciler] {migrated} domain(s) migrated to new master.")
finally:
session.close()
if self.dry_run:
logger.info(
f"[reconciler] {hostname}: all {len(our_domains)} registered "
f"domain(s) confirmed active in DA"
f"[reconciler] Reconciliation pass complete [DRY-RUN] — "
f"{total_queued} orphan(s) identified (none deleted)"
)
return 0
logger.warning(
f"[reconciler] {hostname}: {len(orphans)} orphaned domain(s) "
f"no longer in DA — queuing for deletion: "
f"{[d.domain for d in orphans]}"
)
for record in orphans:
self.delete_queue.put({
"domain": record.domain,
"hostname": record.hostname,
"username": record.username or "",
"source": "reconciler",
})
logger.debug(
f"[reconciler] Queued delete for orphan: {record.domain}"
else:
logger.info(
f"[reconciler] Reconciliation pass complete — "
f"{total_queued} domain(s) queued for deletion"
)
return len(orphans)
def _fetch_da_domains(
self, hostname: str, port: int, username: str, password: str, use_ssl: bool, ipp: int = 1000
):
@@ -265,7 +243,10 @@ class ReconciliationWorker:
total_pages = int(info.get("total_pages", 1))
page += 1
continue
except Exception:
except Exception as e:
logger.error(
f"[reconciler] JSON decode failed for {hostname}:{port} page {page}: {e}\nRaw response: {response.text[:500]}"
)
# Fallback to legacy parser
domains = self._parse_da_domain_list(response.text)
all_domains.update(domains)

View File

@@ -10,6 +10,8 @@ from typing import Any, Dict
def load_config() -> Vyper:
# Initialize Vyper
v.set_config_name("app") # Looks for app.yaml/app.yml
# Bundled config colocated with this module (always present in the package)
v.add_config_path(str(Path(__file__).parent))
v.add_config_path(".") # Search in current directory
v.add_config_path("./config")
v.set_env_prefix("DADNS")
@@ -54,11 +56,15 @@ def load_config() -> Vyper:
# Reconciliation poller defaults
v.set_default("reconciliation.enabled", False)
v.set_default("reconciliation.dry_run", False)
v.set_default("reconciliation.interval_minutes", 60)
v.set_default("reconciliation.verify_ssl", True)
# Read configuration
if not v.read_in_config():
try:
if not v.read_in_config():
logger.warning("No config file found, using defaults")
except Exception:
logger.warning("No config file found, using defaults")
return v

View File

@@ -12,8 +12,10 @@ app:
# If a DA server is unreachable, that server is skipped entirely.
#reconciliation:
# enabled: true
# dry_run: true # log orphans but do NOT queue deletes — safe first-run mode
# interval_minutes: 60
# verify_ssl: true # set false for self-signed DA certs
# ipp: 1000 # items per page when polling DA (default 1000)
# directadmin_servers:
# - hostname: da1.example.com
# port: 2222

View File

@@ -184,29 +184,60 @@ class WorkerManager:
f"skipping ownership check, proceeding with delete"
)
session.delete(record)
session.commit()
logger.info(f"Removed {domain} from database")
remaining_domains = [d.domain for d in session.query(Domain).all()]
backends = self.backend_registry.get_available_backends()
remaining_domains = [d.domain for d in session.query(Domain).all()]
delete_success = True
if not backends:
logger.warning(
f"No active backends — {domain} removed from DB only"
f"No active backends — {domain} will be removed from DB only"
)
elif len(backends) > 1:
self._process_backends_delete_parallel(
backends, domain, remaining_domains
)
# Parallel delete, track failures
results = []
def delete_backend_wrapper(backend_name, backend, domain, remaining_domains):
try:
return backend.delete_zone(domain)
except Exception as e:
logger.error(f"Error deleting {domain} from {backend_name}: {e}")
return False
from concurrent.futures import ThreadPoolExecutor, as_completed
with ThreadPoolExecutor(max_workers=len(backends)) as executor:
futures = {
executor.submit(delete_backend_wrapper, backend_name, backend, domain, remaining_domains): backend_name
for backend_name, backend in backends.items()
}
for future in as_completed(futures):
backend_name = futures[future]
try:
result = future.result()
results.append(result)
if not result:
logger.error(f"Failed to delete {domain} from {backend_name}")
except Exception as e:
logger.error(f"Unhandled error deleting from {backend_name}: {e}")
results.append(False)
delete_success = all(results)
else:
# Single backend
for backend_name, backend in backends.items():
self._delete_single_backend(
backend_name, backend, domain, remaining_domains
)
try:
result = backend.delete_zone(domain)
if not result:
logger.error(f"Failed to delete {domain} from {backend_name}")
delete_success = False
except Exception as e:
logger.error(f"Error deleting {domain} from {backend_name}: {e}")
delete_success = False
self.delete_queue.task_done()
logger.success(f"Delete completed for {domain}")
if delete_success:
session.delete(record)
session.commit()
logger.info(f"Removed {domain} from database")
self.delete_queue.task_done()
logger.success(f"Delete completed for {domain}")
else:
logger.error(f"Delete failed for {domain} on one or more backends — DB record retained")
self.delete_queue.task_done()
except Empty:
continue

36
tests/conftest.py Normal file
View File

@@ -0,0 +1,36 @@
"""Shared test fixtures for directdnsonly test suite."""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from directdnsonly.app.db import Base
from directdnsonly.app.db.models import Domain, Key # noqa: F401 — registers models with Base
@pytest.fixture
def engine():
eng = create_engine("sqlite:///:memory:")
Base.metadata.create_all(eng)
yield eng
eng.dispose()
@pytest.fixture
def db_session(engine):
session = sessionmaker(bind=engine)()
yield session
session.close()
@pytest.fixture
def patch_connect(db_session, monkeypatch):
"""Patch connect() at every call-site, returning the shared test session.
Modules that import connect() directly (e.g. utils, reconciler) are
patched at their local name so the in-memory SQLite session is used
instead of trying to read from vyper config.
"""
_factory = lambda: db_session # noqa: E731
monkeypatch.setattr("directdnsonly.app.utils.connect", _factory)
monkeypatch.setattr("directdnsonly.app.reconciler.connect", _factory)
return db_session

188
tests/test_admin_api.py Normal file
View File

@@ -0,0 +1,188 @@
"""Tests for directdnsonly.app.api.admin — DNSAdminAPI handler methods."""
import pytest
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs
import cherrypy
from directdnsonly.app.api.admin import DNSAdminAPI
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def save_queue():
return MagicMock()
@pytest.fixture
def delete_queue():
return MagicMock()
@pytest.fixture
def api(save_queue, delete_queue):
return DNSAdminAPI(save_queue, delete_queue, backend_registry=MagicMock())
# ---------------------------------------------------------------------------
# CMD_API_LOGIN_TEST
# ---------------------------------------------------------------------------
def test_login_test_returns_success(api):
result = api.CMD_API_LOGIN_TEST()
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["text"] == ["Login OK"]
# ---------------------------------------------------------------------------
# _handle_exists — GET action=exists
# ---------------------------------------------------------------------------
def test_handle_exists_missing_domain_returns_error(api):
with patch.object(cherrypy, "response", MagicMock()):
result = api._handle_exists({"action": "exists"})
parsed = parse_qs(result)
assert parsed["error"] == ["1"]
def test_handle_exists_unsupported_action_returns_error(api):
with patch.object(cherrypy, "response", MagicMock()):
result = api._handle_exists({"action": "rawsave"})
parsed = parse_qs(result)
assert parsed["error"] == ["1"]
def test_handle_exists_domain_not_found(api):
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False), \
patch("directdnsonly.app.api.admin.check_parent_domain_owner", return_value=False):
result = api._handle_exists({"action": "exists", "domain": "example.com"})
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["exists"] == ["0"]
def test_handle_exists_domain_found(api):
record = MagicMock()
record.hostname = "da1.example.com"
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True), \
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record):
result = api._handle_exists({"action": "exists", "domain": "example.com"})
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["exists"] == ["1"]
assert "da1.example.com" in parsed["details"][0]
def test_handle_exists_parent_found(api):
parent = MagicMock()
parent.hostname = "da2.example.com"
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False), \
patch("directdnsonly.app.api.admin.check_parent_domain_owner", return_value=True), \
patch("directdnsonly.app.api.admin.get_parent_domain_record", return_value=parent):
result = api._handle_exists({
"action": "exists",
"domain": "sub.example.com",
"check_for_parent_domain": "1",
})
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["exists"] == ["2"]
assert "da2.example.com" in parsed["details"][0]
def test_handle_exists_no_parent_check_when_flag_absent(api):
"""check_parent_domain_owner should not be called if flag not set."""
record = MagicMock()
record.hostname = "da1.example.com"
with patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True), \
patch("directdnsonly.app.api.admin.check_parent_domain_owner") as mock_parent, \
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record):
api._handle_exists({"action": "exists", "domain": "example.com"})
mock_parent.assert_not_called()
# ---------------------------------------------------------------------------
# _handle_rawsave
# ---------------------------------------------------------------------------
SAMPLE_ZONE = "$ORIGIN example.com.\n$TTL 300\nexample.com. 300 IN A 1.2.3.4\n"
def test_rawsave_enqueues_item(api, save_queue):
with patch("directdnsonly.app.api.admin.validate_and_normalize_zone",
return_value=SAMPLE_ZONE), \
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
result = api._handle_rawsave("example.com", {
"zone_file": SAMPLE_ZONE,
"hostname": "da1.example.com",
"username": "admin",
})
save_queue.put.assert_called_once()
item = save_queue.put.call_args[0][0]
assert item["domain"] == "example.com"
assert item["hostname"] == "da1.example.com"
assert item["username"] == "admin"
assert item["client_ip"] == "127.0.0.1"
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
def test_rawsave_missing_zone_file_raises(api):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
with pytest.raises(ValueError, match="Missing zone file"):
api._handle_rawsave("example.com", {})
def test_rawsave_invalid_zone_raises(api):
with patch("directdnsonly.app.api.admin.validate_and_normalize_zone",
side_effect=ValueError("Invalid zone data: bad record")), \
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
with pytest.raises(ValueError, match="Invalid zone data"):
api._handle_rawsave("example.com", {"zone_file": "garbage"})
# ---------------------------------------------------------------------------
# _handle_delete
# ---------------------------------------------------------------------------
def test_delete_enqueues_item(api, delete_queue):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="10.0.0.1"))):
result = api._handle_delete("example.com", {
"hostname": "da1.example.com",
"username": "admin",
})
delete_queue.put.assert_called_once()
item = delete_queue.put.call_args[0][0]
assert item["domain"] == "example.com"
assert item["hostname"] == "da1.example.com"
assert item["client_ip"] == "10.0.0.1"
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
def test_delete_missing_params_uses_empty_strings(api, delete_queue):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
api._handle_delete("example.com", {})
item = delete_queue.put.call_args[0][0]
assert item["hostname"] == ""
assert item["username"] == ""

View File

@@ -1,47 +1,158 @@
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from directdnsonly.app.backends.coredns_mysql import CoreDNSMySQLBackend, CoreDNSRecord
from loguru import logger
from directdnsonly.app.backends.coredns_mysql import (
Base,
CoreDNSMySQLBackend,
Record,
Zone,
)
# ---------------------------------------------------------------------------
# Fixture — in-memory SQLite backend (bypasses real MySQL connection)
# ---------------------------------------------------------------------------
@pytest.fixture
def mysql_backend(tmp_path):
# Setup in-memory SQLite for testing (replace with test MySQL in CI)
def mysql_backend():
engine = create_engine("sqlite:///:memory:")
CoreDNSRecord.metadata.create_all(engine)
Base.metadata.create_all(engine)
class TestBackend(CoreDNSMySQLBackend):
class _TestBackend(CoreDNSMySQLBackend):
def __init__(self):
super().__init__()
# Manually initialise without triggering the MySQL create_engine call
self.config = {}
self.instance_name = "test"
self.engine = engine
self.Session = scoped_session(sessionmaker(bind=engine))
yield TestBackend()
yield _TestBackend()
engine.dispose()
def test_zone_operations(mysql_backend):
zone_data = """
# ---------------------------------------------------------------------------
# write_zone / zone_exists / delete_zone
# ---------------------------------------------------------------------------
ZONE_DATA = """\
$ORIGIN example.com.
$TTL 300
example.com. 300 IN SOA ns.example.com. admin.example.com. (2023 3600 1800 604800 86400)
example.com. 300 IN A 192.0.2.1
"""
# Test zone creation
assert mysql_backend.write_zone("example.com", zone_data)
def test_write_zone_creates_zone(mysql_backend):
assert mysql_backend.write_zone("example.com", ZONE_DATA)
def test_zone_exists_after_write(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
assert mysql_backend.zone_exists("example.com")
# Test record update
updated_zone = """
def test_zone_does_not_exist_before_write(mysql_backend):
assert not mysql_backend.zone_exists("missing.com")
def test_write_zone_idempotent(mysql_backend):
assert mysql_backend.write_zone("example.com", ZONE_DATA)
assert mysql_backend.write_zone("example.com", ZONE_DATA)
def test_write_zone_updates_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
updated = """\
$ORIGIN example.com.
$TTL 300
example.com. 3600 IN A 192.0.2.1
example.com. 300 IN AAAA 2001:db8::1
"""
assert mysql_backend.write_zone("example.com", updated_zone)
assert mysql_backend.write_zone("example.com", updated)
# Test record removal
reduced_zone = "example.com. 300 IN A 192.0.2.1"
assert mysql_backend.write_zone("example.com", reduced_zone)
# Test zone deletion
def test_write_zone_removes_stale_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
reduced = "example.com. 300 IN A 192.0.2.1"
mysql_backend.write_zone("example.com", reduced)
session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all()
assert records == []
session.close()
def test_delete_zone_removes_zone_and_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
assert mysql_backend.delete_zone("example.com")
assert not mysql_backend.zone_exists("example.com")
assert not mysql_backend.zone_exists("example.com")
def test_delete_nonexistent_zone_returns_false(mysql_backend):
assert not mysql_backend.delete_zone("ghost.com")
def test_reload_zone_returns_true(mysql_backend):
assert mysql_backend.reload_zone("example.com")
assert mysql_backend.reload_zone()
# ---------------------------------------------------------------------------
# verify_zone_record_count
# ---------------------------------------------------------------------------
def test_verify_zone_record_count_match(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
# SOA + A = 2 records total
matches, count = mysql_backend.verify_zone_record_count("example.com", 2)
assert matches
assert count == 2
def test_verify_zone_record_count_mismatch(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
matches, count = mysql_backend.verify_zone_record_count("example.com", 99)
assert not matches
assert count == 2
def test_verify_zone_record_count_missing_zone(mysql_backend):
matches, count = mysql_backend.verify_zone_record_count("ghost.com", 0)
assert not matches
assert count == 0
# ---------------------------------------------------------------------------
# reconcile_zone_records
# ---------------------------------------------------------------------------
def test_reconcile_removes_extra_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
# Inject a phantom record directly into the DB
session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
session.add(Record(zone_id=zone.id, hostname="phantom", type="A",
data="10.0.0.99", ttl=300, online=True))
session.commit()
session.close()
success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA)
assert success
assert removed == 1
def test_reconcile_no_changes_when_zone_matches(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA)
assert success
assert removed == 0

262
tests/test_reconciler.py Normal file
View File

@@ -0,0 +1,262 @@
"""Tests for directdnsonly.app.reconciler — ReconciliationWorker."""
import pytest
import requests.exceptions
from queue import Queue
from unittest.mock import MagicMock, patch
from directdnsonly.app.reconciler import ReconciliationWorker
from directdnsonly.app.db.models import Domain
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
SERVER = {"hostname": "da1.example.com", "port": 2222, "username": "admin", "password": "secret", "ssl": True}
BASE_CONFIG = {
"enabled": True,
"dry_run": False,
"interval_minutes": 60,
"verify_ssl": True,
"ipp": 100,
"directadmin_servers": [SERVER],
}
@pytest.fixture
def delete_queue():
return Queue()
@pytest.fixture
def worker(delete_queue):
return ReconciliationWorker(delete_queue, BASE_CONFIG)
@pytest.fixture
def dry_run_worker(delete_queue):
cfg = {**BASE_CONFIG, "dry_run": True}
return ReconciliationWorker(delete_queue, cfg)
# ---------------------------------------------------------------------------
# _reconcile_all — orphan detection
# ---------------------------------------------------------------------------
def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_connect):
patch_connect.add(Domain(domain="orphan.com", hostname="da1.example.com", username="admin"))
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value=set()):
worker._reconcile_all()
assert not delete_queue.empty()
item = delete_queue.get_nowait()
assert item["domain"] == "orphan.com"
assert item["source"] == "reconciler"
def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connect):
patch_connect.add(Domain(domain="orphan.com", hostname="da1.example.com", username="admin"))
patch_connect.commit()
with patch.object(dry_run_worker, "_fetch_da_domains", return_value=set()):
dry_run_worker._reconcile_all()
assert delete_queue.empty()
def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connect):
"""Domains whose recorded master is NOT in our configured servers are skipped."""
patch_connect.add(Domain(domain="other.com", hostname="da99.unknown.com", username="admin"))
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value=set()):
worker._reconcile_all()
assert delete_queue.empty()
def test_active_domain_not_queued(worker, delete_queue, patch_connect):
patch_connect.add(Domain(domain="good.com", hostname="da1.example.com", username="admin"))
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"good.com"}):
worker._reconcile_all()
assert delete_queue.empty()
# ---------------------------------------------------------------------------
# _reconcile_all — hostname backfill and migration
# ---------------------------------------------------------------------------
def test_backfill_null_hostname(worker, patch_connect):
patch_connect.add(Domain(domain="backfill.com", hostname=None, username="admin"))
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"backfill.com"}):
worker._reconcile_all()
record = patch_connect.query(Domain).filter_by(domain="backfill.com").first()
assert record.hostname == "da1.example.com"
def test_migration_updates_hostname(worker, patch_connect):
patch_connect.add(Domain(domain="moved.com", hostname="da-old.example.com", username="admin"))
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"moved.com"}):
worker._reconcile_all()
record = patch_connect.query(Domain).filter_by(domain="moved.com").first()
assert record.hostname == "da1.example.com"
def test_dry_run_still_backfills(dry_run_worker, patch_connect):
"""Backfill is a data-repair operation, applied even in dry-run mode."""
patch_connect.add(Domain(domain="fill.com", hostname=None, username="admin"))
patch_connect.commit()
with patch.object(dry_run_worker, "_fetch_da_domains", return_value={"fill.com"}):
dry_run_worker._reconcile_all()
record = patch_connect.query(Domain).filter_by(domain="fill.com").first()
assert record.hostname == "da1.example.com"
# ---------------------------------------------------------------------------
# _fetch_da_domains — HTTP handling
# ---------------------------------------------------------------------------
def _make_json_response(domains_dict, total_pages=1):
"""Return a mock requests.Response with JSON payload matching DA format."""
data = {str(i): {"domain": d} for i, d in enumerate(domains_dict)}
data["info"] = {"total_pages": total_pages}
mock = MagicMock()
mock.status_code = 200
mock.is_redirect = False
mock.headers = {"Content-Type": "application/json"}
mock.json.return_value = data
mock.raise_for_status = MagicMock()
return mock
def test_fetch_returns_domains_from_json(worker):
    # A single JSON page yields the complete domain set.
    response = _make_json_response(["example.com", "test.com"])
    with patch("requests.get", return_value=response):
        fetched = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True)
    assert fetched == {"example.com", "test.com"}
def test_fetch_paginates(worker):
    # total_pages=2 forces a second request; results from both pages merge.
    pages = [
        _make_json_response(["a.com"], total_pages=2),
        _make_json_response(["b.com"], total_pages=2),
    ]
    with patch("requests.get", side_effect=pages):
        merged = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True)
    assert merged == {"a.com", "b.com"}
def test_fetch_redirect_triggers_session_login(worker):
    # A 302 means basic auth was refused; the worker falls back to session login.
    bounced = MagicMock()
    bounced.status_code = 302
    bounced.is_redirect = True
    login_patch = patch.object(worker, "_da_session_login", return_value=None)
    with patch("requests.get", return_value=bounced), login_patch:
        outcome = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True)
    assert outcome is None
def test_fetch_html_response_returns_none(worker):
    # An HTML body (e.g. a login page) is not a usable API response.
    resp = MagicMock()
    resp.status_code = 200
    resp.is_redirect = False
    resp.headers = {"Content-Type": "text/html; charset=utf-8"}
    resp.raise_for_status = MagicMock()
    with patch("requests.get", return_value=resp):
        outcome = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True)
    assert outcome is None
def test_fetch_connection_error_returns_none(worker):
    # Network failures are swallowed; the caller sees None, not an exception.
    boom = requests.exceptions.ConnectionError("refused")
    with patch("requests.get", side_effect=boom):
        assert worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) is None
def test_fetch_timeout_returns_none(worker):
    # A request timeout is handled the same way as any other network failure.
    with patch("requests.get", side_effect=requests.exceptions.Timeout()):
        outcome = worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True)
    assert outcome is None
def test_fetch_ssl_error_returns_none(worker):
    # TLS verification problems must not crash the reconciler.
    tls_failure = requests.exceptions.SSLError("cert verify failed")
    with patch("requests.get", side_effect=tls_failure):
        assert worker._fetch_da_domains("da1.example.com", 2222, "admin", "secret", True) is None
# ---------------------------------------------------------------------------
# _parse_da_domain_list — legacy format fallback
# ---------------------------------------------------------------------------
def test_parse_standard_querystring():
    # Ampersand-delimited list[] pairs (classic DA output).
    parsed = ReconciliationWorker._parse_da_domain_list("list[]=example.com&list[]=test.com")
    assert parsed == {"example.com", "test.com"}
def test_parse_newline_separated():
    # Some DA versions emit one list[] entry per line instead of ampersands.
    parsed = ReconciliationWorker._parse_da_domain_list("list[]=example.com\nlist[]=test.com")
    assert parsed == {"example.com", "test.com"}
def test_parse_empty_body_returns_empty_set():
    # An empty body yields an empty set, never None.
    parsed = ReconciliationWorker._parse_da_domain_list("")
    assert parsed == set()
def test_parse_normalises_to_lowercase():
    # Domain names are case-insensitive; the parser lowercases them.
    parsed = ReconciliationWorker._parse_da_domain_list("list[]=EXAMPLE.COM")
    assert "example.com" in parsed
    assert "EXAMPLE.COM" not in parsed
def test_parse_strips_whitespace():
    # Spaces surrounding a value are trimmed before indexing.
    parsed = ReconciliationWorker._parse_da_domain_list("list[]= example.com ")
    assert "example.com" in parsed
# ---------------------------------------------------------------------------
# Worker lifecycle
# ---------------------------------------------------------------------------
def test_disabled_worker_does_not_start(delete_queue):
    # enabled=False must prevent the worker thread from spinning up.
    config = dict(BASE_CONFIG, enabled=False)
    worker = ReconciliationWorker(delete_queue, config)
    worker.start()
    assert not worker.is_alive
def test_no_servers_does_not_start(delete_queue):
    # With no DA servers configured there is nothing to reconcile against.
    config = dict(BASE_CONFIG, directadmin_servers=[])
    worker = ReconciliationWorker(delete_queue, config)
    worker.start()
    assert not worker.is_alive

137
tests/test_utils.py Normal file
View File

@@ -0,0 +1,137 @@
"""Tests for directdnsonly.app.utils — zone index helper functions."""
import pytest
from directdnsonly.app.db.models import Domain
from directdnsonly.app.utils import (
check_zone_exists,
check_parent_domain_owner,
get_domain_record,
get_parent_domain_record,
put_zone_index,
)
# ---------------------------------------------------------------------------
# check_zone_exists
# ---------------------------------------------------------------------------
def test_check_zone_exists_not_found(patch_connect):
    # An empty index knows no domains.
    result = check_zone_exists("example.com")
    assert result is False
def test_check_zone_exists_found(patch_connect):
    # A seeded domain is reported as existing.
    seeded = Domain(domain="example.com", hostname="da1.example.com", username="admin")
    patch_connect.add(seeded)
    patch_connect.commit()
    assert check_zone_exists("example.com") is True
def test_check_zone_exists_does_not_match_partial(patch_connect):
    # Lookup is exact: a stored parent must not answer for its subdomain.
    parent = Domain(domain="example.com", hostname="da1.example.com", username="admin")
    patch_connect.add(parent)
    patch_connect.commit()
    assert check_zone_exists("sub.example.com") is False
# ---------------------------------------------------------------------------
# put_zone_index
# ---------------------------------------------------------------------------
def test_put_zone_index_adds_record(patch_connect):
    # put_zone_index persists domain, hostname and username in one row.
    put_zone_index("new.com", "da1.example.com", "admin")
    stored = patch_connect.query(Domain).filter_by(domain="new.com").first()
    assert stored is not None
    assert stored.hostname == "da1.example.com"
    assert stored.username == "admin"
def test_put_zone_index_stores_domain_name(patch_connect):
    # Round trip: what put_zone_index writes, check_zone_exists can see.
    put_zone_index("another.nz", "da2.example.com", "user1")
    assert check_zone_exists("another.nz") is True
# ---------------------------------------------------------------------------
# get_domain_record
# ---------------------------------------------------------------------------
def test_get_domain_record_returns_none_when_missing(patch_connect):
    # Missing domains yield None rather than raising.
    record = get_domain_record("missing.com")
    assert record is None
def test_get_domain_record_returns_record(patch_connect):
    # A seeded row comes back with its fields intact.
    seeded = Domain(domain="found.com", hostname="da1.example.com", username="admin")
    patch_connect.add(seeded)
    patch_connect.commit()
    fetched = get_domain_record("found.com")
    assert fetched is not None
    assert fetched.domain == "found.com"
    assert fetched.hostname == "da1.example.com"
# ---------------------------------------------------------------------------
# check_parent_domain_owner
# ---------------------------------------------------------------------------
def test_check_parent_domain_owner_not_found(patch_connect):
    # No parent record in the index means no owner.
    result = check_parent_domain_owner("sub.example.com")
    assert result is False
def test_check_parent_domain_owner_found(patch_connect):
    # The immediate parent of sub.example.com is present in the index.
    parent = Domain(domain="example.com", hostname="da1.example.com", username="admin")
    patch_connect.add(parent)
    patch_connect.commit()
    assert check_parent_domain_owner("sub.example.com") is True
def test_check_parent_domain_owner_single_label_returns_false(patch_connect):
    # A single-label name such as "com" has no parent to look up.
    assert check_parent_domain_owner("com") is False
def test_check_parent_domain_owner_ignores_grandparent(patch_connect):
    # Only the immediate parent counts: for deep.sub.example.com that is
    # sub.example.com (absent) — the grandparent example.com is ignored.
    grandparent = Domain(domain="example.com", hostname="da1.example.com", username="admin")
    patch_connect.add(grandparent)
    patch_connect.commit()
    assert check_parent_domain_owner("deep.sub.example.com") is False
# ---------------------------------------------------------------------------
# get_parent_domain_record
# ---------------------------------------------------------------------------
def test_get_parent_domain_record_returns_none_when_missing(patch_connect):
    # No parent row in the index -> None, not an exception.
    parent = get_parent_domain_record("sub.example.com")
    assert parent is None
def test_get_parent_domain_record_returns_parent(patch_connect):
    # The stored parent row is returned for a subdomain query.
    parent_row = Domain(domain="example.com", hostname="da1.example.com", username="admin")
    patch_connect.add(parent_row)
    patch_connect.commit()
    found = get_parent_domain_record("sub.example.com")
    assert found is not None
    assert found.domain == "example.com"
def test_get_parent_domain_record_single_label_returns_none(patch_connect):
    # A single label has no parent zone at all.
    assert get_parent_domain_record("com") is None

100
tests/test_zone_parser.py Normal file
View File

@@ -0,0 +1,100 @@
"""Tests for directdnsonly.app.utils.zone_parser."""
import pytest
from dns.exception import DNSException
from directdnsonly.app.utils.zone_parser import (
count_zone_records,
validate_and_normalize_zone,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# A one-record zone with no $ORIGIN/$TTL directives — exercises the
# normalisation paths that must inject them.
MINIMAL_ZONE = "example.com. 300 IN A 1.2.3.4"
# A complete, already-normalised zone: SOA + NS + two A records + MX.
FULL_ZONE = """\
$ORIGIN example.com.
$TTL 300
@ IN SOA ns1.example.com. admin.example.com. 2024010101 3600 900 604800 300
@ IN NS ns1.example.com.
@ IN A 1.2.3.4
www IN A 5.6.7.8
mail IN MX 10 mail.example.com.
"""
# ---------------------------------------------------------------------------
# validate_and_normalize_zone
# ---------------------------------------------------------------------------
def test_validate_adds_origin_when_missing():
    # A zone lacking $ORIGIN gets one injected for the given domain.
    normalised = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
    assert "$ORIGIN example.com." in normalised
def test_validate_adds_ttl_when_missing():
    # A zone lacking a default TTL gets a $TTL directive injected.
    normalised = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
    assert "$TTL" in normalised
def test_validate_does_not_duplicate_origin():
    # An existing $ORIGIN must not be injected a second time.
    already_has_origin = "$ORIGIN example.com.\nexample.com. 300 IN A 1.2.3.4"
    normalised = validate_and_normalize_zone(already_has_origin, "example.com")
    assert normalised.count("$ORIGIN") == 1
def test_validate_does_not_duplicate_ttl():
    # An existing $TTL must not be injected a second time.
    already_has_ttl = "$TTL 300\nexample.com. 300 IN A 1.2.3.4"
    normalised = validate_and_normalize_zone(already_has_ttl, "example.com")
    assert normalised.count("$TTL") == 1
def test_validate_appends_dot_to_domain():
    # The bare domain argument is dot-terminated in the injected $ORIGIN.
    normalised = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
    assert "$ORIGIN example.com." in normalised
def test_validate_returns_string():
    # The normaliser returns plain text, not a dnspython object.
    outcome = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
    assert isinstance(outcome, str)
def test_validate_full_zone_passes():
    # An already well-formed zone validates without error.
    assert validate_and_normalize_zone(FULL_ZONE, "example.com") is not None
def test_validate_raises_on_invalid_zone():
    # Unparseable input surfaces as a ValueError with a recognisable message.
    garbage = "this is not a zone file at all !!!"
    with pytest.raises(ValueError, match="Invalid zone data"):
        validate_and_normalize_zone(garbage, "example.com")
# ---------------------------------------------------------------------------
# count_zone_records
# ---------------------------------------------------------------------------
def test_count_records_simple_zone():
    # One A plus one AAAA record at the apex.
    simple = "$ORIGIN example.com.\n$TTL 300\n@ IN A 1.2.3.4\n@ IN AAAA ::1\n"
    assert count_zone_records(simple, "example.com") == 2
def test_count_records_soa_included():
    # SOA + NS + A (apex) + A (www) + MX = 5
    total = count_zone_records(FULL_ZONE, "example.com")
    assert total == 5
def test_count_records_returns_negative_on_bad_zone():
    # Unparseable input is signalled with -1 rather than an exception.
    assert count_zone_records("not a valid zone", "example.com") == -1
def test_count_records_empty_zone():
    # Directives alone contribute no resource records.
    directives_only = "$ORIGIN example.com.\n$TTL 300\n"
    assert count_zone_records(directives_only, "example.com") == 0