Compare commits

4 Commits

SHA1 Message Date
0903d78458 fix: update .gitignore to include dist/ and modify build command in justfile 🐛 2026-02-18 23:04:41 +13:00
74c5f4012e style: apply black formatting across codebase 🎨
No logic changes — pure reformatting of line lengths, dict literals,
method-chain line breaks, and trailing newlines to satisfy black's style.
2026-02-18 22:53:09 +13:00
807d6271f1 chore: rewrite justfile for pyenv + poetry dev workflow 🔧
Replaces outdated PyInstaller-only recipe with full task runner:
install, test, coverage, coverage-html, test-one, fmt, fmt-check, ci, run,
build, clean. PATH export wires in pyenv shims and poetry automatically.
2026-02-18 22:46:18 +13:00
bd46227364 feat: add test suite, fix backend bugs, remove legacy artifacts 🧪
- Add 73-test suite across conftest, utils, admin API, reconciler, zone parser,
  and CoreDNS MySQL backend (all green, ~0.5s)
- Fix zone_exists filter using wrong column name (name → zone_name)
- Fix delete_zone missing dot_fqdn normalization on lookup
- Remove spurious unused `from config import config` in coredns_mysql.py
- Fix config loader to search module-relative path so tests find app.yml
  without needing a root-level config/ directory
- Remove legacy v1 Flask prototype (app.py), empty config.json, and
  duplicate root config/app.yml
2026-02-18 22:03:04 +13:00
19 changed files with 1301 additions and 374 deletions

.gitignore

@@ -1,4 +1,5 @@
 *.db
+dist/
 venv/
 .venv
 .idea

app.py

@@ -1,105 +0,0 @@
from flask import Flask, request
import mmap
import re

app = Flask(__name__)


@app.route('/')
def hello_world():
    return 'Hello World!'


@app.route('/CMD_API_LOGIN_TEST')
def login_test():
    multi_dict = request.values
    for key in multi_dict:
        print(multi_dict.get(key))
        print(multi_dict.getlist(key))
    # print(request.values)
    print(request.headers)
    print(request.authorization)
    return 'error=0&text=Login OK&details=none'


@app.route('/CMD_API_DNS_ADMIN', methods=['GET', 'POST'])
def domain_admin():
    print(str(request.data, encoding="utf-8"))
    print(request.values.get('action'))
    action = request.values.get('action')
    if action == 'exists':
        # DirectAdmin is checking whether the domain is in the cluster
        return 'result: exists=1'
    if action == 'delete':
        # Domain is being removed from the DNS
        hostname = request.values.get('hostname')
        username = request.values.get('username')
        domain = request.values.get('select0')
    if action == 'rawsave':
        # DirectAdmin wants to add/update a domain
        hostname = request.values.get('hostname')
        username = request.values.get('username')
        domain = request.values.get('domain')
        if not check_zone_exists(str(domain)):
            put_zone_index(str(domain))
            write_zone_file(str(domain), request.data.decode("utf-8"))
        else:
            # Domain already exists
            write_zone_file(str(domain), request.data.decode("utf-8"))


def create_zone_index():
    # Create an index of all zones present from zone definitions
    regex = r"(?<=\")(?P<domain>.*)(?=\"\s)"
    with open(zone_index_file, 'w+') as f:
        with open(named_conf, 'r') as named_file:
            while True:
                # read line
                line = named_file.readline()
                if not line:
                    # Reached end of file
                    break
                print(line)
                hosted_domain = re.search(regex, line).group(0)
                f.write(hosted_domain + "\n")


def put_zone_index(zone_name):
    # add a new zone to index
    with open(zone_index_file, 'a+') as f:
        # We are using append mode
        f.write(zone_name)


def write_zone_file(zone_name, data):
    # Write the zone to file
    with open(zones_dir + '/' + zone_name + '.db', 'w') as f:
        f.write(data)


def check_zone_exists(zone_name):
    # Check if zone is present in the index
    with open(zone_index_file, 'r') as f:
        try:
            s = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
            if s.find(bytes(zone_name, encoding='utf8')) != -1:
                return True
            else:
                return False
        except ValueError as e:
            # File Empty?
            return False


if __name__ == '__main__':
    zones_dir = "/etc/pdns/zones"
    zone_index_file = "/etc/pdns/zones/.index"
    named_conf = "/etc/pdns/named.conf"
    create_zone_index()
    app.run(host="0.0.0.0")

config.json

@@ -1 +0,0 @@
{}

config/app.yml

@@ -1,29 +0,0 @@
---
timezone: Pacific/Auckland
log_level: INFO
queue_location: ./data/queues

dns:
  # default_backend: coredns_mysql
  backends:
    bind_backend:
      type: bind
      enabled: false
      zones_dir: /etc/named/zones/dadns
      named_conf: /etc/bind/named.conf.local
    coredns_primary:
      enabled: true
      host: "mysql" # Matches Docker service name
      port: 3306
      database: "coredns"
      username: "coredns"
      password: "coredns123"
      table_name: "records"
    coredns_secondary:
      enabled: false
      host: "mysql" # Matches Docker service name
      port: 3306
      database: "coredns"
      username: "coredns"
      password: "coredns123"
      table_name: "records"

directdnsonly/app/api/admin.py

@@ -113,19 +113,23 @@ class DNSAdminAPI:
         if domain_exists:
             record = get_domain_record(domain)
-            return urlencode({
-                "error": 0,
-                "exists": 1,
-                "details": f"Domain exists on {record.hostname}",
-            })
+            return urlencode(
+                {
+                    "error": 0,
+                    "exists": 1,
+                    "details": f"Domain exists on {record.hostname}",
+                }
+            )
 
         # parent match only
         parent_record = get_parent_domain_record(domain)
-        return urlencode({
-            "error": 0,
-            "exists": 2,
-            "details": f"Parent Domain exists on {parent_record.hostname}",
-        })
+        return urlencode(
+            {
+                "error": 0,
+                "exists": 2,
+                "details": f"Parent Domain exists on {parent_record.hostname}",
+            }
+        )
 
     def _handle_rawsave(self, domain: str, params: dict):
         """Process zone file saves"""

directdnsonly/app/backends/coredns_mysql.py

@@ -7,7 +7,6 @@ from dns import zone as dns_zone_module
 from dns.rdataclass import IN
 from loguru import logger
 from .base import DNSBackend
-from config import config
 
 Base = declarative_base()
@@ -120,9 +119,7 @@ class CoreDNSMySQLBackend(DNSBackend):
                     )
                     session.add(existing_soa)
                     changes["added"] += 1
-                logger.debug(
-                    f"Added SOA record: {soa_name} SOA {soa_content}"
-                )
+                logger.debug(f"Added SOA record: {soa_name} SOA {soa_content}")
 
             # Process all non-SOA records
             for record_name, record_type, record_content, record_ttl in source_records:
@@ -173,7 +170,7 @@ class CoreDNSMySQLBackend(DNSBackend):
                     changes["removed"] += 1
 
             session.commit()
-            total_changes = changes['added'] + changes['updated'] + changes['removed']
+            total_changes = changes["added"] + changes["updated"] + changes["removed"]
             if total_changes > 0:
                 logger.info(
                     f"[{self.instance_name}] Zone {zone_name} updated: "
@@ -181,9 +178,7 @@ class CoreDNSMySQLBackend(DNSBackend):
                     f"{changes['removed']} removed"
                 )
             else:
-                logger.debug(
-                    f"[{self.instance_name}] Zone {zone_name}: no changes"
-                )
+                logger.debug(f"[{self.instance_name}] Zone {zone_name}: no changes")
 
             return True
         except Exception as e:
@@ -197,7 +192,11 @@ class CoreDNSMySQLBackend(DNSBackend):
         session = self.Session()
         try:
             # First find the zone
-            zone = session.query(Zone).filter_by(name=zone_name).first()
+            zone = (
+                session.query(Zone)
+                .filter_by(zone_name=self.dot_fqdn(zone_name))
+                .first()
+            )
             if not zone:
                 logger.warning(f"Zone {zone_name} not found for deletion")
                 return False
@@ -231,7 +230,9 @@ class CoreDNSMySQLBackend(DNSBackend):
         session = self.Session()
         try:
             exists = (
-                session.query(Zone).filter_by(name=self.dot_fqdn(zone_name)).first()
+                session.query(Zone)
+                .filter_by(zone_name=self.dot_fqdn(zone_name))
+                .first()
                 is not None
             )
             logger.debug(f"Zone existence check for {zone_name}: {exists}")
@@ -267,17 +268,11 @@ class CoreDNSMySQLBackend(DNSBackend):
             The normalized CNAME target string
         """
         if record_content.startswith("@"):
-            logger.debug(
-                f"CNAME target starts with '@', replacing with zone FQDN"
-            )
+            logger.debug(f"CNAME target starts with '@', replacing with zone FQDN")
             record_content = self.dot_fqdn(zone_name)
         elif not record_content.endswith("."):
-            logger.debug(
-                f"CNAME target {record_content} is relative, appending zone"
-            )
-            record_content = ".".join(
-                [record_content, self.dot_fqdn(zone_name)]
-            )
+            logger.debug(f"CNAME target {record_content} is relative, appending zone")
+            record_content = ".".join([record_content, self.dot_fqdn(zone_name)])
 
         return record_content
 
     def _parse_zone_to_record_set(
@@ -307,9 +302,7 @@ class CoreDNSMySQLBackend(DNSBackend):
                 continue
 
             if record_type == "CNAME":
-                record_content = self._normalize_cname_data(
-                    zone_name, record_content
-                )
+                record_content = self._normalize_cname_data(zone_name, record_content)
 
             records.add((record_name, record_type, record_content, ttl))
@@ -342,9 +335,7 @@ class CoreDNSMySQLBackend(DNSBackend):
                 )
                 return False, 0
 
-            actual_count = (
-                session.query(Record).filter_by(zone_id=zone.id).count()
-            )
+            actual_count = session.query(Record).filter_by(zone_id=zone.id).count()
             matches = actual_count == expected_count
 
             if not matches:
@@ -410,14 +401,11 @@ class CoreDNSMySQLBackend(DNSBackend):
             )
 
             # Build lookup keys (without TTL) matching write_zone's key format
             expected_keys: Set[Tuple[str, str, str]] = {
-                (hostname, rtype, data)
-                for hostname, rtype, data, _ in source_records
+                (hostname, rtype, data) for hostname, rtype, data, _ in source_records
             }
 
             # Query all records currently in the backend for this zone
-            db_records = (
-                session.query(Record).filter_by(zone_id=zone.id).all()
-            )
+            db_records = session.query(Record).filter_by(zone_id=zone.id).all()
 
             removed = 0
             for record in db_records:
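
The two filter_by fixes above are the zone_exists/delete_zone bugs called out in the commit message: both lookups must target the stored zone_name column with a normalized FQDN. A minimal sketch of the idea, assuming (as the surrounding code suggests) that dot_fqdn simply guarantees a trailing dot; the helper's real definition lives elsewhere in the backend:

    # Hypothetical stand-in for the backend's dot_fqdn helper.
    def dot_fqdn(zone_name: str) -> str:
        return zone_name if zone_name.endswith(".") else zone_name + "."

    # Zones are stored dot-terminated ("example.com."), so an un-normalized
    # lookup for "example.com" would silently match nothing.
    assert dot_fqdn("example.com") == dot_fqdn("example.com.") == "example.com."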

directdnsonly/app/reconciler.py

@@ -27,6 +27,8 @@ class ReconciliationWorker:
         self.interval_seconds = reconciliation_config.get("interval_minutes", 60) * 60
         self.servers = reconciliation_config.get("directadmin_servers") or []
         self.verify_ssl = reconciliation_config.get("verify_ssl", True)
+        self.ipp = int(reconciliation_config.get("ipp", 1000))
+        self.dry_run = bool(reconciliation_config.get("dry_run", False))
 
         self._stop_event = threading.Event()
         self._thread = None
@@ -46,11 +48,16 @@ class ReconciliationWorker:
         )
         self._thread.start()
         server_names = [s.get("hostname", "?") for s in self.servers]
+        mode = "DRY-RUN" if self.dry_run else "LIVE"
         logger.info(
-            f"Reconciliation poller started — "
+            f"Reconciliation poller started [{mode}] — "
             f"interval: {self.interval_seconds // 60}m, "
             f"servers: {server_names}"
         )
+        if self.dry_run:
+            logger.warning(
+                "[reconciler] DRY-RUN mode active — orphans will be logged but NOT queued for deletion"
+            )
 
     def stop(self):
         self._stop_event.set()
@@ -93,6 +100,7 @@ class ReconciliationWorker:
                     server.get("username"),
                     server.get("password"),
                     server.get("ssl", True),
+                    ipp=self.ipp,
                 )
                 if da_domains is not None:
                     for d in da_domains:
@@ -101,106 +109,86 @@ class ReconciliationWorker:
                     f"[reconciler] {hostname}: {len(da_domains) if da_domains else 0} active domain(s) in DA"
                 )
             except Exception as e:
-                logger.error(
-                    f"[reconciler] Unexpected error polling {hostname}: {e}"
-                )
+                logger.error(f"[reconciler] Unexpected error polling {hostname}: {e}")
 
         # Now check local DB for all domains, update master if needed, and queue deletes only from recorded master
         session = connect()
-        all_local_domains = session.query(Domain).all()
-        migrated = 0
-        for record in all_local_domains:
-            domain = record.domain
-            recorded_master = record.hostname
-            actual_master = all_da_domains.get(domain)
-            if actual_master:
-                if actual_master != recorded_master:
-                    logger.warning(
-                        f"[reconciler] Domain '{domain}' migrated: recorded master '{recorded_master}' -> new master '{actual_master}'. Updating local DB."
-                    )
-                    record.hostname = actual_master
-                    migrated += 1
-            else:
-                # Only queue delete if this is the recorded master
-                if recorded_master in [s.get("hostname") for s in self.servers]:
-                    self.delete_queue.put({
-                        "domain": record.domain,
-                        "hostname": record.hostname,
-                        "username": record.username or "",
-                        "source": "reconciler",
-                    })
-                    logger.debug(
-                        f"[reconciler] Queued delete for orphan: {record.domain} (master: {recorded_master})"
-                    )
-                    total_queued += 1
-        if migrated:
-            session.commit()
-            logger.info(f"[reconciler] {migrated} domain(s) migrated to new master and updated in DB.")
-
-        logger.info(
-            f"[reconciler] Reconciliation pass complete — "
-            f"{total_queued} domain(s) queued for deletion"
-        )
-
-    def _reconcile_server(self, server: dict) -> int:
-        """Reconcile one DA server. Returns number of domains queued for delete."""
-        hostname = server["hostname"]
-        port = server.get("port", 2222)
-        username = server.get("username")
-        password = server.get("password")
-        use_ssl = server.get("ssl", True)
-
-        logger.info(f"[reconciler] Polling {hostname}:{port}")
-        da_domains = self._fetch_da_domains(
-            hostname, port, username, password, use_ssl
-        )
-        if da_domains is None:
-            # Fetch failed — never delete on uncertainty
-            return 0
-
-        logger.debug(
-            f"[reconciler] {hostname}: {len(da_domains)} active domain(s) in DA"
-        )
-
-        session = connect()
-        our_domains = session.query(Domain).filter_by(hostname=hostname).all()
-        if not our_domains:
-            logger.debug(
-                f"[reconciler] {hostname}: no domains registered from this server"
-            )
-            return 0
-
-        orphans = [d for d in our_domains if d.domain not in da_domains]
-        if not orphans:
-            logger.info(
-                f"[reconciler] {hostname}: all {len(our_domains)} registered "
-                f"domain(s) confirmed active in DA"
-            )
-            return 0
-
-        logger.warning(
-            f"[reconciler] {hostname}: {len(orphans)} orphaned domain(s) "
-            f"no longer in DA — queuing for deletion: "
-            f"{[d.domain for d in orphans]}"
-        )
-        for record in orphans:
-            self.delete_queue.put({
-                "domain": record.domain,
-                "hostname": record.hostname,
-                "username": record.username or "",
-                "source": "reconciler",
-            })
-            logger.debug(
-                f"[reconciler] Queued delete for orphan: {record.domain}"
-            )
-        return len(orphans)
+        try:
+            all_local_domains = session.query(Domain).all()
+            migrated = 0
+            backfilled = 0
+            known_servers = {s.get("hostname") for s in self.servers}
+            for record in all_local_domains:
+                domain = record.domain
+                recorded_master = record.hostname
+                actual_master = all_da_domains.get(domain)
+                if actual_master:
+                    if not recorded_master:
+                        logger.info(
+                            f"[reconciler] Domain '{domain}' hostname backfilled: '{actual_master}'"
+                        )
+                        record.hostname = actual_master
+                        backfilled += 1
+                    elif actual_master != recorded_master:
+                        logger.warning(
+                            f"[reconciler] Domain '{domain}' migrated: "
+                            f"'{recorded_master}' -> '{actual_master}'. Updating local DB."
+                        )
+                        record.hostname = actual_master
+                        migrated += 1
+                else:
+                    # Only act if the recorded master is one we're polling
+                    if recorded_master in known_servers:
+                        if self.dry_run:
+                            logger.warning(
+                                f"[reconciler] [DRY-RUN] Would delete orphan: {record.domain} "
+                                f"(master: {recorded_master})"
+                            )
+                        else:
+                            self.delete_queue.put(
+                                {
+                                    "domain": record.domain,
+                                    "hostname": record.hostname,
+                                    "username": record.username or "",
+                                    "source": "reconciler",
+                                }
+                            )
+                            logger.debug(
+                                f"[reconciler] Queued delete for orphan: {record.domain} "
+                                f"(master: {recorded_master})"
+                            )
+                        total_queued += 1
+            if migrated or backfilled:
+                session.commit()
+                if backfilled:
+                    logger.info(
+                        f"[reconciler] {backfilled} domain(s) had missing hostname backfilled."
+                    )
+                if migrated:
+                    logger.info(
+                        f"[reconciler] {migrated} domain(s) migrated to new master."
+                    )
+        finally:
+            session.close()
+
+        if self.dry_run:
+            logger.info(
+                f"[reconciler] Reconciliation pass complete [DRY-RUN] — "
+                f"{total_queued} orphan(s) identified (none deleted)"
+            )
+        else:
+            logger.info(
+                f"[reconciler] Reconciliation pass complete — "
+                f"{total_queued} domain(s) queued for deletion"
+            )
 
     def _fetch_da_domains(
-        self, hostname: str, port: int, username: str, password: str, use_ssl: bool, ipp: int = 1000
+        self,
+        hostname: str,
+        port: int,
+        username: str,
+        password: str,
+        use_ssl: bool,
+        ipp: int = 1000,
     ):
         """Fetch all domains from a DA server via CMD_DNS_ADMIN (JSON, paging supported).
@@ -227,13 +215,21 @@ class ReconciliationWorker:
             response = requests.get(url, **req_kwargs)
 
-            if response.is_redirect or response.status_code in (301, 302, 303, 307, 308):
+            if response.is_redirect or response.status_code in (
+                301,
+                302,
+                303,
+                307,
+                308,
+            ):
                 if not cookies:
                     logger.debug(
                         f"[reconciler] {hostname}:{port} redirected Basic Auth "
                         f"(HTTP {response.status_code}) — attempting session login (DA Evo)"
                     )
-                cookies = self._da_session_login(scheme, hostname, port, username, password)
+                cookies = self._da_session_login(
+                    scheme, hostname, port, username, password
+                )
                 if cookies is None:
                     return None
                 continue  # retry this page with cookies
@@ -265,7 +261,10 @@ class ReconciliationWorker:
                     total_pages = int(info.get("total_pages", 1))
                     page += 1
                     continue
-                except Exception:
+                except Exception as e:
+                    logger.error(
+                        f"[reconciler] JSON decode failed for {hostname}:{port} page {page}: {e}\nRaw response: {response.text[:500]}"
+                    )
                     # Fallback to legacy parser
                     domains = self._parse_da_domain_list(response.text)
                     all_domains.update(domains)
@@ -298,9 +297,7 @@ class ReconciliationWorker:
             )
             return None
         except Exception as e:
-            logger.error(
-                f"[reconciler] Unexpected error fetching from {hostname}: {e}"
-            )
+            logger.error(f"[reconciler] Unexpected error fetching from {hostname}: {e}")
             return None
 
     def _da_session_login(
@@ -329,12 +326,12 @@ class ReconciliationWorker:
                     f"check username/password."
                 )
                 return None
-            logger.debug(f"[reconciler] {hostname}:{port} session login successful (DA Evo)")
+            logger.debug(
+                f"[reconciler] {hostname}:{port} session login successful (DA Evo)"
+            )
             return response.cookies
         except Exception as e:
-            logger.error(
-                f"[reconciler] {hostname}:{port} session login failed: {e}"
-            )
+            logger.error(f"[reconciler] {hostname}:{port} session login failed: {e}")
             return None
 
     @staticmethod
@@ -353,24 +350,44 @@ class ReconciliationWorker:
         domains = params.get("list[]", [])
         return {d.strip().lower() for d in domains if d.strip()}
 
 
 if __name__ == "__main__":
     import argparse
     import sys
     from queue import Queue
 
-    parser = argparse.ArgumentParser(description="Test DirectAdmin domain fetcher (JSON/paging)")
+    parser = argparse.ArgumentParser(
+        description="Test DirectAdmin domain fetcher (JSON/paging)"
+    )
     parser.add_argument("--hostname", required=True, help="DirectAdmin server hostname")
-    parser.add_argument("--port", type=int, default=2222, help="DirectAdmin port (default: 2222)")
+    parser.add_argument(
+        "--port", type=int, default=2222, help="DirectAdmin port (default: 2222)"
+    )
     parser.add_argument("--username", required=True, help="DirectAdmin admin username")
     parser.add_argument("--password", required=True, help="DirectAdmin admin password")
     parser.add_argument("--ssl", action="store_true", help="Use HTTPS (default: True)")
-    parser.add_argument("--no-ssl", dest="ssl", action="store_false", help="Use HTTP (not recommended)")
+    parser.add_argument(
+        "--no-ssl", dest="ssl", action="store_false", help="Use HTTP (not recommended)"
+    )
     parser.set_defaults(ssl=True)
-    parser.add_argument("--verify-ssl", action="store_true", help="Verify SSL certs (default: True)")
-    parser.add_argument("--no-verify-ssl", dest="verify_ssl", action="store_false", help="Don't verify SSL certs")
+    parser.add_argument(
+        "--verify-ssl", action="store_true", help="Verify SSL certs (default: True)"
+    )
+    parser.add_argument(
+        "--no-verify-ssl",
+        dest="verify_ssl",
+        action="store_false",
+        help="Don't verify SSL certs",
+    )
     parser.set_defaults(verify_ssl=True)
-    parser.add_argument("--ipp", type=int, default=1000, help="Items per page (default: 1000)")
-    parser.add_argument("--print-json", action="store_true", help="Print raw JSON response for first page")
+    parser.add_argument(
+        "--ipp", type=int, default=1000, help="Items per page (default: 1000)"
+    )
+    parser.add_argument(
+        "--print-json",
+        action="store_true",
+        help="Print raw JSON response for first page",
+    )
     args = parser.parse_args()
@@ -391,7 +408,9 @@ if __name__ == "__main__":
     q = Queue()
     worker = ReconciliationWorker(q, config)
     server = config["directadmin_servers"][0]
-    print(f"Fetching domains from {server['hostname']}:{server['port']} (ipp={args.ipp})...")
+    print(
+        f"Fetching domains from {server['hostname']}:{server['port']} (ipp={args.ipp})..."
+    )
 
     # Directly call the fetch method for testing
     domains = worker._fetch_da_domains(
@@ -399,7 +418,7 @@ if __name__ == "__main__":
         server.get("username"),
         server.get("password"),
         server.get("ssl", True),
-        ipp=args.ipp
+        ipp=args.ipp,
     )
 
     if domains is None:
         print("Failed to fetch domains.", file=sys.stderr)

(zone parser)

@@ -59,9 +59,7 @@ def count_zone_records(zone_data: str, domain_name: str) -> int:
                 if rdata.rdclass == IN:
                     count += 1
 
-        logger.debug(
-            f"Source zone {domain_name} contains {count} records"
-        )
+        logger.debug(f"Source zone {domain_name} contains {count} records")
         return count
 
     except DNSException as e:

(config loader)

@@ -10,6 +10,8 @@ from typing import Any, Dict
 def load_config() -> Vyper:
     # Initialize Vyper
     v.set_config_name("app")  # Looks for app.yaml/app.yml
+    # Bundled config colocated with this module (always present in the package)
+    v.add_config_path(str(Path(__file__).parent))
     v.add_config_path(".")  # Search in current directory
     v.add_config_path("./config")
     v.set_env_prefix("DADNS")
@@ -54,11 +56,15 @@ def load_config() -> Vyper:
     # Reconciliation poller defaults
     v.set_default("reconciliation.enabled", False)
+    v.set_default("reconciliation.dry_run", False)
     v.set_default("reconciliation.interval_minutes", 60)
     v.set_default("reconciliation.verify_ssl", True)
 
     # Read configuration
-    if not v.read_in_config():
+    try:
+        if not v.read_in_config():
+            logger.warning("No config file found, using defaults")
+    except Exception:
         logger.warning("No config file found, using defaults")
 
     return v
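
A quick illustration of why the module-relative search path fixes test discovery (a standalone sketch, not the project's loader): a file resolved against __file__ is found regardless of the process's working directory, whereas "." and "./config" only work when the process is launched from the repository root.

    from pathlib import Path

    # Resolves to the same file whether pytest runs from the repo root,
    # a subdirectory, or a packaged install location.
    bundled_config = Path(__file__).parent / "app.yml"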

(example app.yml)

@@ -12,8 +12,10 @@ app:
 # If a DA server is unreachable, that server is skipped entirely.
 #reconciliation:
 #  enabled: true
+#  dry_run: true           # log orphans but do NOT queue deletes — safe first-run mode
 #  interval_minutes: 60
 #  verify_ssl: true        # set false for self-signed DA certs
+#  ipp: 1000               # items per page when polling DA (default 1000)
 #  directadmin_servers:
 #    - hostname: da1.example.com
 #      port: 2222

directdnsonly/main.py

@@ -50,7 +50,9 @@ def main():
     # Configure CherryPy
     user_password_dict = {
-        config.get_string("app.auth_username"): config.get_string("app.auth_password")
+        config.get_string("app.auth_username"): config.get_string(
+            "app.auth_password"
+        )
     }
     check_password = cherrypy.lib.auth_basic.checkpassword_dict(user_password_dict)

(worker manager)

@@ -14,7 +14,9 @@ from directdnsonly.app.reconciler import ReconciliationWorker
 class WorkerManager:
-    def __init__(self, queue_path: str, backend_registry, reconciliation_config: dict = None):
+    def __init__(
+        self, queue_path: str, backend_registry, reconciliation_config: dict = None
+    ):
         self.queue_path = queue_path
         self.backend_registry = backend_registry
         self._running = False
@@ -86,9 +88,7 @@ class WorkerManager:
                         f"{len(backends)} backends concurrently: "
                         f"{', '.join(backends.keys())}"
                     )
-                    self._process_backends_parallel(
-                        backends, item, session
-                    )
+                    self._process_backends_parallel(backends, item, session)
                 else:
                     # Single backend, no need for thread overhead
                     for backend_name, backend in backends.items():
@@ -126,9 +126,7 @@ class WorkerManager:
         try:
             logger.debug(f"Using backend: {backend_name}")
             if backend.write_zone(item["domain"], item["zone_file"]):
-                logger.debug(
-                    f"Successfully updated {item['domain']} in {backend_name}"
-                )
+                logger.debug(f"Successfully updated {item['domain']} in {backend_name}")
                 if backend.get_name() == "bind":
                     # Need to update the named.conf
                     backend.update_named_conf(
@@ -144,9 +142,7 @@ class WorkerManager:
                     backend_name, backend, item["domain"], item["zone_file"]
                 )
             else:
-                logger.error(
-                    f"Failed to update {item['domain']} in {backend_name}"
-                )
+                logger.error(f"Failed to update {item['domain']} in {backend_name}")
         except Exception as e:
             logger.error(f"Error in {backend_name}: {str(e)}")
@@ -165,9 +161,7 @@ class WorkerManager:
                 record = session.query(Domain).filter_by(domain=domain).first()
                 if not record:
-                    logger.warning(
-                        f"Domain {domain} not found in DB — skipping delete"
-                    )
+                    logger.warning(f"Domain {domain} not found in DB — skipping delete")
                     self.delete_queue.task_done()
                     continue
@@ -184,29 +178,83 @@ class WorkerManager:
                         f"skipping ownership check, proceeding with delete"
                     )
 
-                session.delete(record)
-                session.commit()
-                logger.info(f"Removed {domain} from database")
-
-                remaining_domains = [d.domain for d in session.query(Domain).all()]
                 backends = self.backend_registry.get_available_backends()
+                remaining_domains = [d.domain for d in session.query(Domain).all()]
+                delete_success = True
+
                 if not backends:
                     logger.warning(
-                        f"No active backends — {domain} removed from DB only"
+                        f"No active backends — {domain} will be removed from DB only"
                     )
                 elif len(backends) > 1:
-                    self._process_backends_delete_parallel(
-                        backends, domain, remaining_domains
-                    )
-                else:
-                    for backend_name, backend in backends.items():
-                        self._delete_single_backend(
-                            backend_name, backend, domain, remaining_domains
-                        )
+                    # Parallel delete, track failures
+                    results = []
 
-                self.delete_queue.task_done()
-                logger.success(f"Delete completed for {domain}")
+                    def delete_backend_wrapper(
+                        backend_name, backend, domain, remaining_domains
+                    ):
+                        try:
+                            return backend.delete_zone(domain)
+                        except Exception as e:
+                            logger.error(
+                                f"Error deleting {domain} from {backend_name}: {e}"
+                            )
+                            return False
+
+                    from concurrent.futures import ThreadPoolExecutor, as_completed
+
+                    with ThreadPoolExecutor(max_workers=len(backends)) as executor:
+                        futures = {
+                            executor.submit(
+                                delete_backend_wrapper,
+                                backend_name,
+                                backend,
+                                domain,
+                                remaining_domains,
+                            ): backend_name
+                            for backend_name, backend in backends.items()
+                        }
+                        for future in as_completed(futures):
+                            backend_name = futures[future]
+                            try:
+                                result = future.result()
+                                results.append(result)
+                                if not result:
+                                    logger.error(
+                                        f"Failed to delete {domain} from {backend_name}"
+                                    )
+                            except Exception as e:
+                                logger.error(
+                                    f"Unhandled error deleting from {backend_name}: {e}"
+                                )
+                                results.append(False)
+                    delete_success = all(results)
+                else:
+                    # Single backend
+                    for backend_name, backend in backends.items():
+                        try:
+                            result = backend.delete_zone(domain)
+                            if not result:
+                                logger.error(
+                                    f"Failed to delete {domain} from {backend_name}"
+                                )
+                                delete_success = False
+                        except Exception as e:
+                            logger.error(
+                                f"Error deleting {domain} from {backend_name}: {e}"
+                            )
+                            delete_success = False
+
+                if delete_success:
+                    session.delete(record)
+                    session.commit()
+                    logger.info(f"Removed {domain} from database")
+                    self.delete_queue.task_done()
+                    logger.success(f"Delete completed for {domain}")
+                else:
+                    logger.error(
+                        f"Delete failed for {domain} on one or more backends — DB record retained"
+                    )
+                    self.delete_queue.task_done()
 
             except Empty:
                 continue
@@ -239,7 +287,10 @@ class WorkerManager:
             futures = {
                 executor.submit(
                     self._delete_single_backend,
-                    backend_name, backend, domain, remaining_domains
+                    backend_name,
+                    backend,
+                    domain,
+                    remaining_domains,
                 ): backend_name
                 for backend_name, backend in backends.items()
             }
@@ -248,9 +299,7 @@ class WorkerManager:
                 try:
                     future.result()
                 except Exception as e:
-                    logger.error(
-                        f"Unhandled error deleting from {backend_name}: {e}"
-                    )
+                    logger.error(f"Unhandled error deleting from {backend_name}: {e}")
 
         elapsed = (time.monotonic() - start_time) * 1000
         logger.debug(
             f"Parallel delete of {domain} across "
@@ -261,13 +310,11 @@ class WorkerManager:
         """Process zone updates across multiple backends in parallel"""
         start_time = time.monotonic()
         with ThreadPoolExecutor(
-            max_workers=len(backends),
-            thread_name_prefix="backend"
+            max_workers=len(backends), thread_name_prefix="backend"
         ) as executor:
             futures = {
                 executor.submit(
-                    self._process_single_backend,
-                    backend_name, backend, item, session
+                    self._process_single_backend, backend_name, backend, item, session
                 ): backend_name
                 for backend_name, backend in backends.items()
             }
@@ -286,9 +333,7 @@ class WorkerManager:
             f"{len(backends)} backends completed in {elapsed:.0f}ms"
         )
 
-    def _verify_backend_record_count(
-        self, backend_name, backend, zone_name, zone_data
-    ):
+    def _verify_backend_record_count(self, backend_name, backend, zone_name, zone_data):
         """Verify and reconcile the backend record count against the
         authoritative BIND zone from DirectAdmin.
@@ -313,9 +358,7 @@ class WorkerManager:
             )
             return
 
-        matches, actual = backend.verify_zone_record_count(
-            zone_name, expected
-        )
+        matches, actual = backend.verify_zone_record_count(zone_name, expected)
         if matches:
             return  # All good
@@ -326,9 +369,7 @@ class WorkerManager:
             f"record(s) for {zone_name} — reconciling against "
             f"DirectAdmin source zone"
         )
-        success, removed = backend.reconcile_zone_records(
-            zone_name, zone_data
-        )
+        success, removed = backend.reconcile_zone_records(zone_name, zone_data)
         if success and removed > 0:
             # Verify again after reconciliation
             matches, new_count = backend.verify_zone_record_count(
@@ -406,6 +447,9 @@ class WorkerManager:
             "save_queue_size": self.save_queue.qsize(),
             "delete_queue_size": self.delete_queue.qsize(),
             "save_worker_alive": self._save_thread and self._save_thread.is_alive(),
-            "delete_worker_alive": self._delete_thread and self._delete_thread.is_alive(),
-            "reconciler_alive": self._reconciler.is_alive if self._reconciler else False,
+            "delete_worker_alive": self._delete_thread
+            and self._delete_thread.is_alive(),
+            "reconciler_alive": (
+                self._reconciler.is_alive if self._reconciler else False
+            ),
         }
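
The substantive change above: the DB row is now deleted only after every backend confirms the zone delete, where previously the row was dropped before the backends were even attempted. A condensed sketch of the new ordering (a hypothetical standalone function, not the worker's exact code):

    from concurrent.futures import ThreadPoolExecutor

    def delete_everywhere(backends: dict, domain: str) -> bool:
        if not backends:
            return True  # nothing to do; caller decides how to log this case

        # One boolean per backend; an exception counts as failure.
        def safe_delete(backend):
            try:
                return backend.delete_zone(domain)
            except Exception:
                return False

        with ThreadPoolExecutor(max_workers=len(backends)) as pool:
            results = list(pool.map(safe_delete, backends.values()))
        # Drop the local record only when all backends succeeded;
        # otherwise keep it so a later pass can retry.
        return all(results)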

justfile

@@ -1,17 +1,98 @@
 #!/usr/bin/env just --justfile
+# directdnsonly — developer task runner
+# Requires: just, pyenv, poetry
 
 APP_NAME := "directdnsonly"
 
+# Ensure pyenv shims and common install locations are on PATH so that `python`
+# resolves via pyenv (.python-version) and `poetry` is found without a full
+# shell init in every recipe.
+export PATH := env_var("HOME") + "/.pyenv/shims:" + env_var("HOME") + "/.pyenv/bin:" + env_var("HOME") + "/.local/bin:" + env_var("PATH")
+
+# List available recipes (default)
+default:
+    @just --list
+
+# ---------------------------------------------------------------------------
+# Setup
+# ---------------------------------------------------------------------------
+
+# Install all dependencies (including dev group)
+install:
+    poetry install
+
+# Install only production dependencies
+install-prod:
+    poetry install --only main
+
+# Show the Python interpreter that will be used
+which-python:
+    @poetry run python --version
+    @poetry run python -c "import sys; print(sys.executable)"
+
+# ---------------------------------------------------------------------------
+# Testing
+# ---------------------------------------------------------------------------
+
+# Run the full test suite
+test:
+    poetry run pytest tests/ -v
+
+# Run tests with terminal coverage report
+coverage:
+    poetry run pytest tests/ -v --cov=directdnsonly --cov-report=term-missing
+
+# Run tests with HTML coverage report (opens in browser)
+coverage-html:
+    poetry run pytest tests/ --cov=directdnsonly --cov-report=html
+    @echo "Coverage report: htmlcov/index.html"
+
+# Run a single test file or pattern, e.g. just test-one test_reconciler
+test-one target:
+    poetry run pytest tests/ -v -k "{{target}}"
+
+# ---------------------------------------------------------------------------
+# Code quality
+# ---------------------------------------------------------------------------
+
+# Format all source and test files with black
+fmt:
+    poetry run black directdnsonly/ tests/
+
+# Check formatting without making changes (CI-safe)
+fmt-check:
+    poetry run black --check directdnsonly/ tests/
+
+# CI gate — run fmt-check then test, fail fast
+ci: fmt-check test
+
+# Start the application
+run:
+    poetry run python -m directdnsonly
+
+# ---------------------------------------------------------------------------
+# Build
+# ---------------------------------------------------------------------------
+
+# Build a standalone binary with PyInstaller
 build:
-    cd src && \
-    pyinstaller \
-    -p . \
-    --hidden-import=json \
-    --hidden-import=pyopenssl \
-    --hidden-import=pymysql \
-    --hidden-import=jaraco \
-    --hidden-import=cheroot \
-    --hidden-import=cheroot.ssl.pyopenssl \
-    --hidden-import=cheroot.ssl.builtin \
-    --hidden-import=lib \
-    --hidden-import=os \
-    --hidden-import=builtins \
-    --noconfirm --onefile {{APP_NAME}}.py
+    poetry run pyinstaller \
+        --hidden-import=json \
+        --hidden-import=pymysql \
+        --hidden-import=cheroot \
+        --hidden-import=cheroot.ssl.pyopenssl \
+        --hidden-import=cheroot.ssl.builtin \
+        --noconfirm --onefile \
+        --name=directdnsonly \
+        directdnsonly/main.py
+    rm -f *.spec
+
+# ---------------------------------------------------------------------------
+# Clean
+# ---------------------------------------------------------------------------
+
+# Remove build artefacts, caches, and compiled bytecode
+clean:
+    rm -rf dist/ build/ *.spec .coverage htmlcov/ .pytest_cache/
+    find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
+    find . -name "*.pyc" -delete 2>/dev/null || true
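
For orientation, the day-to-day loop these recipes support would look something like this (assuming just, pyenv, and poetry are installed; all recipe names come from the justfile above):

    just install                  # poetry install, including the dev group
    just test-one test_reconciler # iterate on a single test module
    just ci                       # the CI gate: black --check, then pytest
    just build                    # one-file PyInstaller binary in dist/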

tests/conftest.py

@@ -0,0 +1,40 @@
"""Shared test fixtures for directdnsonly test suite."""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from directdnsonly.app.db import Base
from directdnsonly.app.db.models import (
Domain,
Key,
) # noqa: F401 — registers models with Base
@pytest.fixture
def engine():
eng = create_engine("sqlite:///:memory:")
Base.metadata.create_all(eng)
yield eng
eng.dispose()
@pytest.fixture
def db_session(engine):
session = sessionmaker(bind=engine)()
yield session
session.close()
@pytest.fixture
def patch_connect(db_session, monkeypatch):
"""Patch connect() at every call-site, returning the shared test session.
Modules that import connect() directly (e.g. utils, reconciler) are
patched at their local name so the in-memory SQLite session is used
instead of trying to read from vyper config.
"""
_factory = lambda: db_session # noqa: E731
monkeypatch.setattr("directdnsonly.app.utils.connect", _factory)
monkeypatch.setattr("directdnsonly.app.reconciler.connect", _factory)
return db_session
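
The reason patch_connect patches each call-site rather than the defining module: `from x import connect` binds a new name at import time, so replacing the original afterwards does not touch the copies. A minimal, self-contained illustration of the gotcha (hypothetical module names):

    import types

    db = types.ModuleType("db")
    db.connect = lambda: "real"

    consumer = types.ModuleType("consumer")
    consumer.connect = db.connect   # what `from db import connect` effectively does

    db.connect = lambda: "patched"  # patching only the source module...
    print(consumer.connect())       # -> "real" (the consumer is unaffected)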

tests/test_admin_api.py

@@ -0,0 +1,219 @@
"""Tests for directdnsonly.app.api.admin — DNSAdminAPI handler methods."""
import pytest
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs
import cherrypy
from directdnsonly.app.api.admin import DNSAdminAPI
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def save_queue():
return MagicMock()
@pytest.fixture
def delete_queue():
return MagicMock()
@pytest.fixture
def api(save_queue, delete_queue):
return DNSAdminAPI(save_queue, delete_queue, backend_registry=MagicMock())
# ---------------------------------------------------------------------------
# CMD_API_LOGIN_TEST
# ---------------------------------------------------------------------------
def test_login_test_returns_success(api):
result = api.CMD_API_LOGIN_TEST()
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["text"] == ["Login OK"]
# ---------------------------------------------------------------------------
# _handle_exists — GET action=exists
# ---------------------------------------------------------------------------
def test_handle_exists_missing_domain_returns_error(api):
with patch.object(cherrypy, "response", MagicMock()):
result = api._handle_exists({"action": "exists"})
parsed = parse_qs(result)
assert parsed["error"] == ["1"]
def test_handle_exists_unsupported_action_returns_error(api):
with patch.object(cherrypy, "response", MagicMock()):
result = api._handle_exists({"action": "rawsave"})
parsed = parse_qs(result)
assert parsed["error"] == ["1"]
def test_handle_exists_domain_not_found(api):
with (
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False),
patch(
"directdnsonly.app.api.admin.check_parent_domain_owner", return_value=False
),
):
result = api._handle_exists({"action": "exists", "domain": "example.com"})
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["exists"] == ["0"]
def test_handle_exists_domain_found(api):
record = MagicMock()
record.hostname = "da1.example.com"
with (
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True),
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record),
):
result = api._handle_exists({"action": "exists", "domain": "example.com"})
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["exists"] == ["1"]
assert "da1.example.com" in parsed["details"][0]
def test_handle_exists_parent_found(api):
parent = MagicMock()
parent.hostname = "da2.example.com"
with (
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False),
patch(
"directdnsonly.app.api.admin.check_parent_domain_owner", return_value=True
),
patch(
"directdnsonly.app.api.admin.get_parent_domain_record", return_value=parent
),
):
result = api._handle_exists(
{
"action": "exists",
"domain": "sub.example.com",
"check_for_parent_domain": "1",
}
)
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
assert parsed["exists"] == ["2"]
assert "da2.example.com" in parsed["details"][0]
def test_handle_exists_no_parent_check_when_flag_absent(api):
"""check_parent_domain_owner should not be called if flag not set."""
record = MagicMock()
record.hostname = "da1.example.com"
with (
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True),
patch("directdnsonly.app.api.admin.check_parent_domain_owner") as mock_parent,
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record),
):
api._handle_exists({"action": "exists", "domain": "example.com"})
mock_parent.assert_not_called()
# ---------------------------------------------------------------------------
# _handle_rawsave
# ---------------------------------------------------------------------------
SAMPLE_ZONE = "$ORIGIN example.com.\n$TTL 300\nexample.com. 300 IN A 1.2.3.4\n"
def test_rawsave_enqueues_item(api, save_queue):
with (
patch(
"directdnsonly.app.api.admin.validate_and_normalize_zone",
return_value=SAMPLE_ZONE,
),
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))),
):
result = api._handle_rawsave(
"example.com",
{
"zone_file": SAMPLE_ZONE,
"hostname": "da1.example.com",
"username": "admin",
},
)
save_queue.put.assert_called_once()
item = save_queue.put.call_args[0][0]
assert item["domain"] == "example.com"
assert item["hostname"] == "da1.example.com"
assert item["username"] == "admin"
assert item["client_ip"] == "127.0.0.1"
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
def test_rawsave_missing_zone_file_raises(api):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
with pytest.raises(ValueError, match="Missing zone file"):
api._handle_rawsave("example.com", {})
def test_rawsave_invalid_zone_raises(api):
with (
patch(
"directdnsonly.app.api.admin.validate_and_normalize_zone",
side_effect=ValueError("Invalid zone data: bad record"),
),
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))),
):
with pytest.raises(ValueError, match="Invalid zone data"):
api._handle_rawsave("example.com", {"zone_file": "garbage"})
# ---------------------------------------------------------------------------
# _handle_delete
# ---------------------------------------------------------------------------
def test_delete_enqueues_item(api, delete_queue):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="10.0.0.1"))):
result = api._handle_delete(
"example.com",
{
"hostname": "da1.example.com",
"username": "admin",
},
)
delete_queue.put.assert_called_once()
item = delete_queue.put.call_args[0][0]
assert item["domain"] == "example.com"
assert item["hostname"] == "da1.example.com"
assert item["client_ip"] == "10.0.0.1"
parsed = parse_qs(result)
assert parsed["error"] == ["0"]
def test_delete_missing_params_uses_empty_strings(api, delete_queue):
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
api._handle_delete("example.com", {})
item = delete_queue.put.call_args[0][0]
assert item["hostname"] == ""
assert item["username"] == ""

(CoreDNS MySQL backend tests)

@@ -1,47 +1,167 @@
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from directdnsonly.app.backends.coredns_mysql import CoreDNSMySQLBackend, CoreDNSRecord from directdnsonly.app.backends.coredns_mysql import (
from loguru import logger Base,
CoreDNSMySQLBackend,
Record,
Zone,
)
# ---------------------------------------------------------------------------
# Fixture — in-memory SQLite backend (bypasses real MySQL connection)
# ---------------------------------------------------------------------------
@pytest.fixture @pytest.fixture
def mysql_backend(tmp_path): def mysql_backend():
# Setup in-memory SQLite for testing (replace with test MySQL in CI)
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
CoreDNSRecord.metadata.create_all(engine) Base.metadata.create_all(engine)
class TestBackend(CoreDNSMySQLBackend): class _TestBackend(CoreDNSMySQLBackend):
def __init__(self): def __init__(self):
super().__init__() # Manually initialise without triggering the MySQL create_engine call
self.config = {}
self.instance_name = "test"
self.engine = engine self.engine = engine
self.Session = scoped_session(sessionmaker(bind=engine)) self.Session = scoped_session(sessionmaker(bind=engine))
yield TestBackend() yield _TestBackend()
engine.dispose() engine.dispose()
def test_zone_operations(mysql_backend): # ---------------------------------------------------------------------------
zone_data = """ # write_zone / zone_exists / delete_zone
# ---------------------------------------------------------------------------
ZONE_DATA = """\
$ORIGIN example.com.
$TTL 300
example.com. 300 IN SOA ns.example.com. admin.example.com. (2023 3600 1800 604800 86400) example.com. 300 IN SOA ns.example.com. admin.example.com. (2023 3600 1800 604800 86400)
example.com. 300 IN A 192.0.2.1 example.com. 300 IN A 192.0.2.1
""" """
# Test zone creation
assert mysql_backend.write_zone("example.com", zone_data)
def test_write_zone_creates_zone(mysql_backend):
assert mysql_backend.write_zone("example.com", ZONE_DATA)
def test_zone_exists_after_write(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
assert mysql_backend.zone_exists("example.com") assert mysql_backend.zone_exists("example.com")
# Test record update
updated_zone = """ def test_zone_does_not_exist_before_write(mysql_backend):
assert not mysql_backend.zone_exists("missing.com")
def test_write_zone_idempotent(mysql_backend):
assert mysql_backend.write_zone("example.com", ZONE_DATA)
assert mysql_backend.write_zone("example.com", ZONE_DATA)
def test_write_zone_updates_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
updated = """\
$ORIGIN example.com.
$TTL 300
example.com. 3600 IN A 192.0.2.1 example.com. 3600 IN A 192.0.2.1
example.com. 300 IN AAAA 2001:db8::1 example.com. 300 IN AAAA 2001:db8::1
""" """
assert mysql_backend.write_zone("example.com", updated_zone) assert mysql_backend.write_zone("example.com", updated)
# Test record removal
reduced_zone = "example.com. 300 IN A 192.0.2.1"
assert mysql_backend.write_zone("example.com", reduced_zone)
# Test zone deletion def test_write_zone_removes_stale_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
reduced = "example.com. 300 IN A 192.0.2.1"
mysql_backend.write_zone("example.com", reduced)
session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all()
assert records == []
session.close()
def test_delete_zone_removes_zone_and_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
assert mysql_backend.delete_zone("example.com") assert mysql_backend.delete_zone("example.com")
assert not mysql_backend.zone_exists("example.com") assert not mysql_backend.zone_exists("example.com")
def test_delete_nonexistent_zone_returns_false(mysql_backend):
assert not mysql_backend.delete_zone("ghost.com")
def test_reload_zone_returns_true(mysql_backend):
assert mysql_backend.reload_zone("example.com")
assert mysql_backend.reload_zone()
# ---------------------------------------------------------------------------
# verify_zone_record_count
# ---------------------------------------------------------------------------
def test_verify_zone_record_count_match(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
# SOA + A = 2 records total
matches, count = mysql_backend.verify_zone_record_count("example.com", 2)
assert matches
assert count == 2
def test_verify_zone_record_count_mismatch(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
matches, count = mysql_backend.verify_zone_record_count("example.com", 99)
assert not matches
assert count == 2
def test_verify_zone_record_count_missing_zone(mysql_backend):
matches, count = mysql_backend.verify_zone_record_count("ghost.com", 0)
assert not matches
assert count == 0
# ---------------------------------------------------------------------------
# reconcile_zone_records
# ---------------------------------------------------------------------------
def test_reconcile_removes_extra_records(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
# Inject a phantom record directly into the DB
session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
session.add(
Record(
zone_id=zone.id,
hostname="phantom",
type="A",
data="10.0.0.99",
ttl=300,
online=True,
)
)
session.commit()
session.close()
success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA)
assert success
assert removed == 1
def test_reconcile_no_changes_when_zone_matches(mysql_backend):
mysql_backend.write_zone("example.com", ZONE_DATA)
success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA)
assert success
assert removed == 0

tests/test_reconciler.py

@@ -0,0 +1,299 @@
"""Tests for directdnsonly.app.reconciler — ReconciliationWorker."""
import pytest
import requests.exceptions
from queue import Queue
from unittest.mock import MagicMock, patch
from directdnsonly.app.reconciler import ReconciliationWorker
from directdnsonly.app.db.models import Domain
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
SERVER = {
"hostname": "da1.example.com",
"port": 2222,
"username": "admin",
"password": "secret",
"ssl": True,
}
BASE_CONFIG = {
"enabled": True,
"dry_run": False,
"interval_minutes": 60,
"verify_ssl": True,
"ipp": 100,
"directadmin_servers": [SERVER],
}
@pytest.fixture
def delete_queue():
return Queue()
@pytest.fixture
def worker(delete_queue):
return ReconciliationWorker(delete_queue, BASE_CONFIG)
@pytest.fixture
def dry_run_worker(delete_queue):
cfg = {**BASE_CONFIG, "dry_run": True}
return ReconciliationWorker(delete_queue, cfg)
# ---------------------------------------------------------------------------
# _reconcile_all — orphan detection
# ---------------------------------------------------------------------------
def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_connect):
patch_connect.add(
Domain(domain="orphan.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value=set()):
worker._reconcile_all()
assert not delete_queue.empty()
item = delete_queue.get_nowait()
assert item["domain"] == "orphan.com"
assert item["source"] == "reconciler"
def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connect):
patch_connect.add(
Domain(domain="orphan.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
with patch.object(dry_run_worker, "_fetch_da_domains", return_value=set()):
dry_run_worker._reconcile_all()
assert delete_queue.empty()
def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connect):
"""Domains whose recorded master is NOT in our configured servers are skipped."""
patch_connect.add(
Domain(domain="other.com", hostname="da99.unknown.com", username="admin")
)
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value=set()):
worker._reconcile_all()
assert delete_queue.empty()
def test_active_domain_not_queued(worker, delete_queue, patch_connect):
patch_connect.add(
Domain(domain="good.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"good.com"}):
worker._reconcile_all()
assert delete_queue.empty()
# ---------------------------------------------------------------------------
# _reconcile_all — hostname backfill and migration
# ---------------------------------------------------------------------------
def test_backfill_null_hostname(worker, patch_connect):
patch_connect.add(Domain(domain="backfill.com", hostname=None, username="admin"))
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"backfill.com"}):
worker._reconcile_all()
record = patch_connect.query(Domain).filter_by(domain="backfill.com").first()
assert record.hostname == "da1.example.com"
def test_migration_updates_hostname(worker, patch_connect):
patch_connect.add(
Domain(domain="moved.com", hostname="da-old.example.com", username="admin")
)
patch_connect.commit()
with patch.object(worker, "_fetch_da_domains", return_value={"moved.com"}):
worker._reconcile_all()
record = patch_connect.query(Domain).filter_by(domain="moved.com").first()
assert record.hostname == "da1.example.com"
def test_dry_run_still_backfills(dry_run_worker, patch_connect):
"""Backfill is a data-repair operation, applied even in dry-run mode."""
patch_connect.add(Domain(domain="fill.com", hostname=None, username="admin"))
patch_connect.commit()
with patch.object(dry_run_worker, "_fetch_da_domains", return_value={"fill.com"}):
dry_run_worker._reconcile_all()
record = patch_connect.query(Domain).filter_by(domain="fill.com").first()
assert record.hostname == "da1.example.com"
# ---------------------------------------------------------------------------
# _fetch_da_domains — HTTP handling
# ---------------------------------------------------------------------------
def _make_json_response(domains_dict, total_pages=1):
"""Return a mock requests.Response with JSON payload matching DA format."""
data = {str(i): {"domain": d} for i, d in enumerate(domains_dict)}
data["info"] = {"total_pages": total_pages}
mock = MagicMock()
mock.status_code = 200
mock.is_redirect = False
mock.headers = {"Content-Type": "application/json"}
mock.json.return_value = data
mock.raise_for_status = MagicMock()
return mock
def test_fetch_returns_domains_from_json(worker):
mock_resp = _make_json_response(["example.com", "test.com"])
with patch("requests.get", return_value=mock_resp):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result == {"example.com", "test.com"}
def test_fetch_paginates(worker):
page1 = _make_json_response(["a.com"], total_pages=2)
page2 = _make_json_response(["b.com"], total_pages=2)
with patch("requests.get", side_effect=[page1, page2]):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result == {"a.com", "b.com"}
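# The two tests above fix the JSON contract: numbered keys hold domain dicts
# and info.total_pages drives pagination. A sketch of the page loop, assuming
# a get_page(page) callable (hypothetical name):
def _page_loop_sketch(get_page):
    domains, page, total = set(), 1, 1
    while page <= total:
        data = get_page(page)
        total = int(data.get("info", {}).get("total_pages", 1))
        domains |= {v["domain"].lower() for k, v in data.items() if k != "info"}
        page += 1
    return domains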
def test_fetch_redirect_triggers_session_login(worker):
redirect_resp = MagicMock()
redirect_resp.status_code = 302
redirect_resp.is_redirect = True
with (
patch("requests.get", return_value=redirect_resp),
patch.object(worker, "_da_session_login", return_value=None),
):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None
def test_fetch_html_response_returns_none(worker):
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.is_redirect = False
mock_resp.headers = {"Content-Type": "text/html; charset=utf-8"}
mock_resp.raise_for_status = MagicMock()
with patch("requests.get", return_value=mock_resp):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None
def test_fetch_connection_error_returns_none(worker):
with patch(
"requests.get", side_effect=requests.exceptions.ConnectionError("refused")
):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None
def test_fetch_timeout_returns_none(worker):
with patch("requests.get", side_effect=requests.exceptions.Timeout()):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None
def test_fetch_ssl_error_returns_none(worker):
with patch(
"requests.get", side_effect=requests.exceptions.SSLError("cert verify failed")
):
result = worker._fetch_da_domains(
"da1.example.com", 2222, "admin", "secret", True
)
assert result is None
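# A sketch of the failure contract the six tests above exercise (hypothetical
# shape — the real method also paginates and logs; requests is already
# imported at module top). Any transport error, redirect-to-login, or
# non-JSON body yields None, so the reconciler treats the server as
# unreachable rather than as serving zero domains.
def _fetch_contract_sketch(url, **kwargs):
    try:
        resp = requests.get(url, allow_redirects=False, timeout=30, **kwargs)
    except requests.exceptions.RequestException:    # connection, timeout, SSL…
        return None
    if resp.is_redirect:                            # expired session → re-login path
        return None
    if "application/json" not in resp.headers.get("Content-Type", ""):
        return None                                 # HTML login page, not the API
    resp.raise_for_status()
    return {v["domain"].lower() for k, v in resp.json().items() if k != "info"}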
# ---------------------------------------------------------------------------
# _parse_da_domain_list — legacy format fallback
# ---------------------------------------------------------------------------
def test_parse_standard_querystring():
body = "list[]=example.com&list[]=test.com"
result = ReconciliationWorker._parse_da_domain_list(body)
assert result == {"example.com", "test.com"}
def test_parse_newline_separated():
body = "list[]=example.com\nlist[]=test.com"
result = ReconciliationWorker._parse_da_domain_list(body)
assert result == {"example.com", "test.com"}
def test_parse_empty_body_returns_empty_set():
assert ReconciliationWorker._parse_da_domain_list("") == set()
def test_parse_normalises_to_lowercase():
result = ReconciliationWorker._parse_da_domain_list("list[]=EXAMPLE.COM")
assert "example.com" in result
assert "EXAMPLE.COM" not in result
def test_parse_strips_whitespace():
result = ReconciliationWorker._parse_da_domain_list("list[]= example.com ")
assert "example.com" in result
# ---------------------------------------------------------------------------
# Worker lifecycle
# ---------------------------------------------------------------------------
def test_disabled_worker_does_not_start(delete_queue):
cfg = {**BASE_CONFIG, "enabled": False}
w = ReconciliationWorker(delete_queue, cfg)
w.start()
assert not w.is_alive
def test_no_servers_does_not_start(delete_queue):
cfg = {**BASE_CONFIG, "directadmin_servers": []}
w = ReconciliationWorker(delete_queue, cfg)
w.start()
assert not w.is_alive
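# Sketch of the start() guard both lifecycle tests rely on (hypothetical
# shape; `import threading` assumed, and is_alive exposed as a property):
def _start_guard_sketch(self):
    if not self.config.get("enabled", True):
        return                              # feature disabled: no thread
    if not self.config.get("directadmin_servers"):
        return                              # nothing to reconcile against
    self._thread = threading.Thread(target=self._run, daemon=True)
    self._thread.start()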

138
tests/test_utils.py Normal file
@@ -0,0 +1,138 @@
"""Tests for directdnsonly.app.utils — zone index helper functions."""
import pytest
from directdnsonly.app.db.models import Domain
from directdnsonly.app.utils import (
check_zone_exists,
check_parent_domain_owner,
get_domain_record,
get_parent_domain_record,
put_zone_index,
)
# ---------------------------------------------------------------------------
# check_zone_exists
# ---------------------------------------------------------------------------
def test_check_zone_exists_not_found(patch_connect):
assert check_zone_exists("example.com") is False
def test_check_zone_exists_found(patch_connect):
patch_connect.add(
Domain(domain="example.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
assert check_zone_exists("example.com") is True
def test_check_zone_exists_does_not_match_partial(patch_connect):
patch_connect.add(
Domain(domain="example.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
assert check_zone_exists("sub.example.com") is False
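# The three tests above pin exact-match semantics — a sketch, assuming the
# session/Domain wiring from conftest:
def _exists_sketch(session, name):
    return session.query(Domain).filter_by(domain=name).first() is not None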
# ---------------------------------------------------------------------------
# put_zone_index
# ---------------------------------------------------------------------------
def test_put_zone_index_adds_record(patch_connect):
put_zone_index("new.com", "da1.example.com", "admin")
record = patch_connect.query(Domain).filter_by(domain="new.com").first()
assert record is not None
assert record.hostname == "da1.example.com"
assert record.username == "admin"
def test_put_zone_index_stores_domain_name(patch_connect):
put_zone_index("another.nz", "da2.example.com", "user1")
assert check_zone_exists("another.nz") is True
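# put_zone_index presumably just persists one Domain row (sketch, same
# conftest assumptions):
def _put_sketch(session, domain, hostname, username):
    session.add(Domain(domain=domain, hostname=hostname, username=username))
    session.commit()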
# ---------------------------------------------------------------------------
# get_domain_record
# ---------------------------------------------------------------------------
def test_get_domain_record_returns_none_when_missing(patch_connect):
assert get_domain_record("missing.com") is None
def test_get_domain_record_returns_record(patch_connect):
patch_connect.add(
Domain(domain="found.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
record = get_domain_record("found.com")
assert record is not None
assert record.domain == "found.com"
assert record.hostname == "da1.example.com"
# ---------------------------------------------------------------------------
# check_parent_domain_owner
# ---------------------------------------------------------------------------
def test_check_parent_domain_owner_not_found(patch_connect):
assert check_parent_domain_owner("sub.example.com") is False
def test_check_parent_domain_owner_found(patch_connect):
patch_connect.add(
Domain(domain="example.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
assert check_parent_domain_owner("sub.example.com") is True
def test_check_parent_domain_owner_single_label_returns_false(patch_connect):
# A single-label name like "com" has no parent
assert check_parent_domain_owner("com") is False
def test_check_parent_domain_owner_ignores_grandparent(patch_connect):
# Only the immediate parent is checked, not grandparents
patch_connect.add(
Domain(domain="example.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
# deep.sub.example.com's immediate parent is sub.example.com (not in DB)
assert check_parent_domain_owner("deep.sub.example.com") is False
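# Sketch of the immediate-parent rule the four tests above encode
# (hypothetical helper): strip exactly one label; single- and two-label
# names have no usable parent, and grandparents are never consulted.
def _immediate_parent_sketch(name):
    labels = name.split(".")
    return ".".join(labels[1:]) if len(labels) > 2 else None
# _immediate_parent_sketch("deep.sub.example.com") -> "sub.example.com"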
# ---------------------------------------------------------------------------
# get_parent_domain_record
# ---------------------------------------------------------------------------
def test_get_parent_domain_record_returns_none_when_missing(patch_connect):
assert get_parent_domain_record("sub.example.com") is None
def test_get_parent_domain_record_returns_parent(patch_connect):
patch_connect.add(
Domain(domain="example.com", hostname="da1.example.com", username="admin")
)
patch_connect.commit()
parent = get_parent_domain_record("sub.example.com")
assert parent is not None
assert parent.domain == "example.com"
def test_get_parent_domain_record_single_label_returns_none(patch_connect):
assert get_parent_domain_record("com") is None

101
tests/test_zone_parser.py Normal file
@@ -0,0 +1,101 @@
"""Tests for directdnsonly.app.utils.zone_parser."""
import pytest
from dns.exception import DNSException
from directdnsonly.app.utils.zone_parser import (
count_zone_records,
validate_and_normalize_zone,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
MINIMAL_ZONE = "example.com. 300 IN A 1.2.3.4"
FULL_ZONE = """\
$ORIGIN example.com.
$TTL 300
@ IN SOA ns1.example.com. admin.example.com. 2024010101 3600 900 604800 300
@ IN NS ns1.example.com.
@ IN A 1.2.3.4
www IN A 5.6.7.8
mail IN MX 10 mail.example.com.
"""
# ---------------------------------------------------------------------------
# validate_and_normalize_zone
# ---------------------------------------------------------------------------
def test_validate_adds_origin_when_missing():
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
assert "$ORIGIN example.com." in result
def test_validate_adds_ttl_when_missing():
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
assert "$TTL" in result
def test_validate_does_not_duplicate_origin():
zone = "$ORIGIN example.com.\nexample.com. 300 IN A 1.2.3.4"
result = validate_and_normalize_zone(zone, "example.com")
assert result.count("$ORIGIN") == 1
def test_validate_does_not_duplicate_ttl():
zone = "$TTL 300\nexample.com. 300 IN A 1.2.3.4"
result = validate_and_normalize_zone(zone, "example.com")
assert result.count("$TTL") == 1
def test_validate_appends_dot_to_domain():
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
assert "$ORIGIN example.com." in result
def test_validate_returns_string():
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
assert isinstance(result, str)
def test_validate_full_zone_passes():
result = validate_and_normalize_zone(FULL_ZONE, "example.com")
assert result is not None
def test_validate_raises_on_invalid_zone():
bad_zone = "this is not a zone file at all !!!"
with pytest.raises(ValueError, match="Invalid zone data"):
validate_and_normalize_zone(bad_zone, "example.com")
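# A plausible shape for validate_and_normalize_zone, given the eight tests
# above (sketch only — assumes dnspython, whose DNSException the module
# already imports; the 300s default TTL is a guess):
import dns.zone
def _validate_sketch(zone_text, domain):
    origin = domain if domain.endswith(".") else domain + "."
    if "$ORIGIN" not in zone_text:
        zone_text = f"$ORIGIN {origin}\n" + zone_text
    if "$TTL" not in zone_text:
        zone_text = "$TTL 300\n" + zone_text
    try:
        # check_origin=False: the SOA-less MINIMAL_ZONE must still validate
        dns.zone.from_text(zone_text, origin=origin, relativize=False,
                           check_origin=False)
    except DNSException as exc:
        raise ValueError(f"Invalid zone data: {exc}") from exc
    return zone_text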
# ---------------------------------------------------------------------------
# count_zone_records
# ---------------------------------------------------------------------------
def test_count_records_simple_zone():
zone = "$ORIGIN example.com.\n$TTL 300\n@ IN A 1.2.3.4\n@ IN AAAA ::1\n"
count = count_zone_records(zone, "example.com")
assert count == 2
def test_count_records_soa_included():
count = count_zone_records(FULL_ZONE, "example.com")
# SOA + NS + A (apex) + A (www) + MX = 5
assert count == 5
def test_count_records_returns_negative_on_bad_zone():
count = count_zone_records("not a valid zone", "example.com")
assert count == -1
def test_count_records_empty_zone():
zone = "$ORIGIN example.com.\n$TTL 300\n"
count = count_zone_records(zone, "example.com")
assert count == 0
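# count_zone_records, as the three tests above constrain it (sketch — same
# dnspython parse as the previous sketch; -1 is the unparseable sentinel):
def _count_sketch(zone_text, domain):
    try:
        zone = dns.zone.from_text(zone_text, origin=domain + ".",
                                  relativize=False, check_origin=False)
    except DNSException:
        return -1
    return sum(len(rdataset) for _, rdataset in zone.iterate_rdatasets())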