You've already forked directdnsonly
Compare commits
4 Commits
b8f12d0208
...
0903d78458
| Author | SHA1 | Date | |
|---|---|---|---|
| 0903d78458 | |||
| 74c5f4012e | |||
| 807d6271f1 | |||
| bd46227364 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
*.db
|
||||
dist/
|
||||
venv/
|
||||
.venv
|
||||
.idea
|
||||
|
||||
105
app.py
105
app.py
@@ -1,105 +0,0 @@
|
||||
from flask import Flask, request
|
||||
import mmap
|
||||
import re
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/')
|
||||
def hello_world():
|
||||
return 'Hello World!'
|
||||
|
||||
|
||||
@app.route('/CMD_API_LOGIN_TEST')
|
||||
def login_test():
|
||||
multi_dict = request.values
|
||||
for key in multi_dict:
|
||||
print(multi_dict.get(key))
|
||||
print(multi_dict.getlist(key))
|
||||
# print(request.values)
|
||||
print(request.headers)
|
||||
print(request.authorization)
|
||||
|
||||
return 'error=0&text=Login OK&details=none'
|
||||
|
||||
|
||||
@app.route('/CMD_API_DNS_ADMIN', methods=['GET', 'POST'])
|
||||
def domain_admin():
|
||||
print(str(request.data, encoding="utf-8"))
|
||||
print(request.values.get('action'))
|
||||
action = request.values.get('action')
|
||||
if action == 'exists':
|
||||
# DirectAdmin is checking whether the domain is in the cluster
|
||||
return 'result: exists=1'
|
||||
if action == 'delete':
|
||||
# Domain is being removed from the DNS
|
||||
hostname = request.values.get('hostname')
|
||||
username = request.values.get('username')
|
||||
domain = request.values.get('select0')
|
||||
|
||||
|
||||
if action == 'rawsave':
|
||||
# DirectAdmin wants to add/update a domain
|
||||
hostname = request.values.get('hostname')
|
||||
username = request.values.get('username')
|
||||
domain = request.values.get('domain')
|
||||
|
||||
if not check_zone_exists(str(domain)):
|
||||
put_zone_index(str(domain))
|
||||
write_zone_file(str(domain), request.data.decode("utf-8"))
|
||||
else:
|
||||
# Domain already exists
|
||||
write_zone_file(str(domain), request.data.decode("utf-8"))
|
||||
|
||||
|
||||
def create_zone_index():
|
||||
# Create an index of all zones present from zone definitions
|
||||
regex = r"(?<=\")(?P<domain>.*)(?=\"\s)"
|
||||
|
||||
with open(zone_index_file, 'w+') as f:
|
||||
with open(named_conf, 'r') as named_file:
|
||||
while True:
|
||||
# read line
|
||||
line = named_file.readline()
|
||||
if not line:
|
||||
# Reached end of file
|
||||
break
|
||||
print(line)
|
||||
hosted_domain = re.search(regex, line).group(0)
|
||||
f.write(hosted_domain + "\n")
|
||||
|
||||
|
||||
def put_zone_index(zone_name):
|
||||
# add a new zone to index
|
||||
with open(zone_index_file, 'a+') as f:
|
||||
# We are using append mode
|
||||
f.write(zone_name)
|
||||
|
||||
|
||||
def write_zone_file(zone_name, data):
|
||||
# Write the zone to file
|
||||
with open(zones_dir + '/' + zone_name + '.db', 'w') as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
def check_zone_exists(zone_name):
|
||||
# Check if zone is present in the index
|
||||
with open(zone_index_file, 'r') as f:
|
||||
try:
|
||||
s = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
if s.find(bytes(zone_name, encoding='utf8')) != -1:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except ValueError as e:
|
||||
# File Empty?
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
zones_dir = "/etc/pdns/zones"
|
||||
zone_index_file = "/etc/pdns/zones/.index"
|
||||
named_conf = "/etc/pdns/named.conf"
|
||||
create_zone_index()
|
||||
|
||||
app.run(host="0.0.0.0")
|
||||
@@ -1 +0,0 @@
|
||||
{}
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
timezone: Pacific/Auckland
|
||||
log_level: INFO
|
||||
queue_location: ./data/queues
|
||||
|
||||
dns:
|
||||
# default_backend: coredns_mysql
|
||||
backends:
|
||||
bind_backend:
|
||||
type: bind
|
||||
enabled: false
|
||||
zones_dir: /etc/named/zones/dadns
|
||||
named_conf: /etc/bind/named.conf.local
|
||||
coredns_primary:
|
||||
enabled: true
|
||||
host: "mysql" # Matches Docker service name
|
||||
port: 3306
|
||||
database: "coredns"
|
||||
username: "coredns"
|
||||
password: "coredns123"
|
||||
table_name: "records"
|
||||
coredns_secondary:
|
||||
enabled: false
|
||||
host: "mysql" # Matches Docker service name
|
||||
port: 3306
|
||||
database: "coredns"
|
||||
username: "coredns"
|
||||
password: "coredns123"
|
||||
table_name: "records"
|
||||
@@ -113,19 +113,23 @@ class DNSAdminAPI:
|
||||
|
||||
if domain_exists:
|
||||
record = get_domain_record(domain)
|
||||
return urlencode({
|
||||
"error": 0,
|
||||
"exists": 1,
|
||||
"details": f"Domain exists on {record.hostname}",
|
||||
})
|
||||
return urlencode(
|
||||
{
|
||||
"error": 0,
|
||||
"exists": 1,
|
||||
"details": f"Domain exists on {record.hostname}",
|
||||
}
|
||||
)
|
||||
|
||||
# parent match only
|
||||
parent_record = get_parent_domain_record(domain)
|
||||
return urlencode({
|
||||
"error": 0,
|
||||
"exists": 2,
|
||||
"details": f"Parent Domain exists on {parent_record.hostname}",
|
||||
})
|
||||
return urlencode(
|
||||
{
|
||||
"error": 0,
|
||||
"exists": 2,
|
||||
"details": f"Parent Domain exists on {parent_record.hostname}",
|
||||
}
|
||||
)
|
||||
|
||||
def _handle_rawsave(self, domain: str, params: dict):
|
||||
"""Process zone file saves"""
|
||||
|
||||
@@ -7,7 +7,6 @@ from dns import zone as dns_zone_module
|
||||
from dns.rdataclass import IN
|
||||
from loguru import logger
|
||||
from .base import DNSBackend
|
||||
from config import config
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -120,9 +119,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
)
|
||||
session.add(existing_soa)
|
||||
changes["added"] += 1
|
||||
logger.debug(
|
||||
f"Added SOA record: {soa_name} SOA {soa_content}"
|
||||
)
|
||||
logger.debug(f"Added SOA record: {soa_name} SOA {soa_content}")
|
||||
|
||||
# Process all non-SOA records
|
||||
for record_name, record_type, record_content, record_ttl in source_records:
|
||||
@@ -173,7 +170,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
changes["removed"] += 1
|
||||
|
||||
session.commit()
|
||||
total_changes = changes['added'] + changes['updated'] + changes['removed']
|
||||
total_changes = changes["added"] + changes["updated"] + changes["removed"]
|
||||
if total_changes > 0:
|
||||
logger.info(
|
||||
f"[{self.instance_name}] Zone {zone_name} updated: "
|
||||
@@ -181,9 +178,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
f"{changes['removed']} removed"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[{self.instance_name}] Zone {zone_name}: no changes"
|
||||
)
|
||||
logger.debug(f"[{self.instance_name}] Zone {zone_name}: no changes")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -197,7 +192,11 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
session = self.Session()
|
||||
try:
|
||||
# First find the zone
|
||||
zone = session.query(Zone).filter_by(name=zone_name).first()
|
||||
zone = (
|
||||
session.query(Zone)
|
||||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||
.first()
|
||||
)
|
||||
if not zone:
|
||||
logger.warning(f"Zone {zone_name} not found for deletion")
|
||||
return False
|
||||
@@ -231,7 +230,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
session = self.Session()
|
||||
try:
|
||||
exists = (
|
||||
session.query(Zone).filter_by(name=self.dot_fqdn(zone_name)).first()
|
||||
session.query(Zone)
|
||||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
logger.debug(f"Zone existence check for {zone_name}: {exists}")
|
||||
@@ -267,17 +268,11 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
The normalized CNAME target string
|
||||
"""
|
||||
if record_content.startswith("@"):
|
||||
logger.debug(
|
||||
f"CNAME target starts with '@', replacing with zone FQDN"
|
||||
)
|
||||
logger.debug(f"CNAME target starts with '@', replacing with zone FQDN")
|
||||
record_content = self.dot_fqdn(zone_name)
|
||||
elif not record_content.endswith("."):
|
||||
logger.debug(
|
||||
f"CNAME target {record_content} is relative, appending zone"
|
||||
)
|
||||
record_content = ".".join(
|
||||
[record_content, self.dot_fqdn(zone_name)]
|
||||
)
|
||||
logger.debug(f"CNAME target {record_content} is relative, appending zone")
|
||||
record_content = ".".join([record_content, self.dot_fqdn(zone_name)])
|
||||
return record_content
|
||||
|
||||
def _parse_zone_to_record_set(
|
||||
@@ -307,9 +302,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
continue
|
||||
|
||||
if record_type == "CNAME":
|
||||
record_content = self._normalize_cname_data(
|
||||
zone_name, record_content
|
||||
)
|
||||
record_content = self._normalize_cname_data(zone_name, record_content)
|
||||
|
||||
records.add((record_name, record_type, record_content, ttl))
|
||||
|
||||
@@ -342,9 +335,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
)
|
||||
return False, 0
|
||||
|
||||
actual_count = (
|
||||
session.query(Record).filter_by(zone_id=zone.id).count()
|
||||
)
|
||||
actual_count = session.query(Record).filter_by(zone_id=zone.id).count()
|
||||
matches = actual_count == expected_count
|
||||
|
||||
if not matches:
|
||||
@@ -410,14 +401,11 @@ class CoreDNSMySQLBackend(DNSBackend):
|
||||
)
|
||||
# Build lookup keys (without TTL) matching write_zone's key format
|
||||
expected_keys: Set[Tuple[str, str, str]] = {
|
||||
(hostname, rtype, data)
|
||||
for hostname, rtype, data, _ in source_records
|
||||
(hostname, rtype, data) for hostname, rtype, data, _ in source_records
|
||||
}
|
||||
|
||||
# Query all records currently in the backend for this zone
|
||||
db_records = (
|
||||
session.query(Record).filter_by(zone_id=zone.id).all()
|
||||
)
|
||||
db_records = session.query(Record).filter_by(zone_id=zone.id).all()
|
||||
|
||||
removed = 0
|
||||
for record in db_records:
|
||||
|
||||
@@ -27,6 +27,8 @@ class ReconciliationWorker:
|
||||
self.interval_seconds = reconciliation_config.get("interval_minutes", 60) * 60
|
||||
self.servers = reconciliation_config.get("directadmin_servers") or []
|
||||
self.verify_ssl = reconciliation_config.get("verify_ssl", True)
|
||||
self.ipp = int(reconciliation_config.get("ipp", 1000))
|
||||
self.dry_run = bool(reconciliation_config.get("dry_run", False))
|
||||
self._stop_event = threading.Event()
|
||||
self._thread = None
|
||||
|
||||
@@ -46,11 +48,16 @@ class ReconciliationWorker:
|
||||
)
|
||||
self._thread.start()
|
||||
server_names = [s.get("hostname", "?") for s in self.servers]
|
||||
mode = "DRY-RUN" if self.dry_run else "LIVE"
|
||||
logger.info(
|
||||
f"Reconciliation poller started — "
|
||||
f"Reconciliation poller started [{mode}] — "
|
||||
f"interval: {self.interval_seconds // 60}m, "
|
||||
f"servers: {server_names}"
|
||||
)
|
||||
if self.dry_run:
|
||||
logger.warning(
|
||||
"[reconciler] DRY-RUN mode active — orphans will be logged but NOT queued for deletion"
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
self._stop_event.set()
|
||||
@@ -93,6 +100,7 @@ class ReconciliationWorker:
|
||||
server.get("username"),
|
||||
server.get("password"),
|
||||
server.get("ssl", True),
|
||||
ipp=self.ipp,
|
||||
)
|
||||
if da_domains is not None:
|
||||
for d in da_domains:
|
||||
@@ -101,106 +109,86 @@ class ReconciliationWorker:
|
||||
f"[reconciler] {hostname}: {len(da_domains) if da_domains else 0} active domain(s) in DA"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[reconciler] Unexpected error polling {hostname}: {e}"
|
||||
)
|
||||
logger.error(f"[reconciler] Unexpected error polling {hostname}: {e}")
|
||||
|
||||
# Now check local DB for all domains, update master if needed, and queue deletes only from recorded master
|
||||
session = connect()
|
||||
all_local_domains = session.query(Domain).all()
|
||||
migrated = 0
|
||||
for record in all_local_domains:
|
||||
domain = record.domain
|
||||
recorded_master = record.hostname
|
||||
actual_master = all_da_domains.get(domain)
|
||||
if actual_master:
|
||||
if actual_master != recorded_master:
|
||||
logger.warning(
|
||||
f"[reconciler] Domain '{domain}' migrated: recorded master '{recorded_master}' -> new master '{actual_master}'. Updating local DB."
|
||||
try:
|
||||
all_local_domains = session.query(Domain).all()
|
||||
migrated = 0
|
||||
backfilled = 0
|
||||
known_servers = {s.get("hostname") for s in self.servers}
|
||||
for record in all_local_domains:
|
||||
domain = record.domain
|
||||
recorded_master = record.hostname
|
||||
actual_master = all_da_domains.get(domain)
|
||||
if actual_master:
|
||||
if not recorded_master:
|
||||
logger.info(
|
||||
f"[reconciler] Domain '{domain}' hostname backfilled: '{actual_master}'"
|
||||
)
|
||||
record.hostname = actual_master
|
||||
backfilled += 1
|
||||
elif actual_master != recorded_master:
|
||||
logger.warning(
|
||||
f"[reconciler] Domain '{domain}' migrated: "
|
||||
f"'{recorded_master}' -> '{actual_master}'. Updating local DB."
|
||||
)
|
||||
record.hostname = actual_master
|
||||
migrated += 1
|
||||
else:
|
||||
# Only act if the recorded master is one we're polling
|
||||
if recorded_master in known_servers:
|
||||
if self.dry_run:
|
||||
logger.warning(
|
||||
f"[reconciler] [DRY-RUN] Would delete orphan: {record.domain} "
|
||||
f"(master: {recorded_master})"
|
||||
)
|
||||
else:
|
||||
self.delete_queue.put(
|
||||
{
|
||||
"domain": record.domain,
|
||||
"hostname": record.hostname,
|
||||
"username": record.username or "",
|
||||
"source": "reconciler",
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"[reconciler] Queued delete for orphan: {record.domain} "
|
||||
f"(master: {recorded_master})"
|
||||
)
|
||||
total_queued += 1
|
||||
if migrated or backfilled:
|
||||
session.commit()
|
||||
if backfilled:
|
||||
logger.info(
|
||||
f"[reconciler] {backfilled} domain(s) had missing hostname backfilled."
|
||||
)
|
||||
record.hostname = actual_master
|
||||
migrated += 1
|
||||
else:
|
||||
# Only queue delete if this is the recorded master
|
||||
if recorded_master in [s.get("hostname") for s in self.servers]:
|
||||
self.delete_queue.put({
|
||||
"domain": record.domain,
|
||||
"hostname": record.hostname,
|
||||
"username": record.username or "",
|
||||
"source": "reconciler",
|
||||
})
|
||||
logger.debug(
|
||||
f"[reconciler] Queued delete for orphan: {record.domain} (master: {recorded_master})"
|
||||
if migrated:
|
||||
logger.info(
|
||||
f"[reconciler] {migrated} domain(s) migrated to new master."
|
||||
)
|
||||
total_queued += 1
|
||||
if migrated:
|
||||
session.commit()
|
||||
logger.info(f"[reconciler] {migrated} domain(s) migrated to new master and updated in DB.")
|
||||
logger.info(
|
||||
f"[reconciler] Reconciliation pass complete — "
|
||||
f"{total_queued} domain(s) queued for deletion"
|
||||
)
|
||||
|
||||
def _reconcile_server(self, server: dict) -> int:
|
||||
"""Reconcile one DA server. Returns number of domains queued for delete."""
|
||||
hostname = server["hostname"]
|
||||
port = server.get("port", 2222)
|
||||
username = server.get("username")
|
||||
password = server.get("password")
|
||||
use_ssl = server.get("ssl", True)
|
||||
|
||||
logger.info(f"[reconciler] Polling {hostname}:{port}")
|
||||
|
||||
da_domains = self._fetch_da_domains(
|
||||
hostname, port, username, password, use_ssl
|
||||
)
|
||||
if da_domains is None:
|
||||
# Fetch failed — never delete on uncertainty
|
||||
return 0
|
||||
|
||||
logger.debug(
|
||||
f"[reconciler] {hostname}: {len(da_domains)} active domain(s) in DA"
|
||||
)
|
||||
|
||||
session = connect()
|
||||
our_domains = session.query(Domain).filter_by(hostname=hostname).all()
|
||||
|
||||
if not our_domains:
|
||||
logger.debug(
|
||||
f"[reconciler] {hostname}: no domains registered from this server"
|
||||
)
|
||||
return 0
|
||||
|
||||
orphans = [d for d in our_domains if d.domain not in da_domains]
|
||||
|
||||
if not orphans:
|
||||
finally:
|
||||
session.close()
|
||||
if self.dry_run:
|
||||
logger.info(
|
||||
f"[reconciler] {hostname}: all {len(our_domains)} registered "
|
||||
f"domain(s) confirmed active in DA"
|
||||
f"[reconciler] Reconciliation pass complete [DRY-RUN] — "
|
||||
f"{total_queued} orphan(s) identified (none deleted)"
|
||||
)
|
||||
return 0
|
||||
|
||||
logger.warning(
|
||||
f"[reconciler] {hostname}: {len(orphans)} orphaned domain(s) "
|
||||
f"no longer in DA — queuing for deletion: "
|
||||
f"{[d.domain for d in orphans]}"
|
||||
)
|
||||
|
||||
for record in orphans:
|
||||
self.delete_queue.put({
|
||||
"domain": record.domain,
|
||||
"hostname": record.hostname,
|
||||
"username": record.username or "",
|
||||
"source": "reconciler",
|
||||
})
|
||||
logger.debug(
|
||||
f"[reconciler] Queued delete for orphan: {record.domain}"
|
||||
else:
|
||||
logger.info(
|
||||
f"[reconciler] Reconciliation pass complete — "
|
||||
f"{total_queued} domain(s) queued for deletion"
|
||||
)
|
||||
|
||||
return len(orphans)
|
||||
|
||||
def _fetch_da_domains(
|
||||
self, hostname: str, port: int, username: str, password: str, use_ssl: bool, ipp: int = 1000
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
username: str,
|
||||
password: str,
|
||||
use_ssl: bool,
|
||||
ipp: int = 1000,
|
||||
):
|
||||
"""Fetch all domains from a DA server via CMD_DNS_ADMIN (JSON, paging supported).
|
||||
|
||||
@@ -227,13 +215,21 @@ class ReconciliationWorker:
|
||||
|
||||
response = requests.get(url, **req_kwargs)
|
||||
|
||||
if response.is_redirect or response.status_code in (301, 302, 303, 307, 308):
|
||||
if response.is_redirect or response.status_code in (
|
||||
301,
|
||||
302,
|
||||
303,
|
||||
307,
|
||||
308,
|
||||
):
|
||||
if not cookies:
|
||||
logger.debug(
|
||||
f"[reconciler] {hostname}:{port} redirected Basic Auth "
|
||||
f"(HTTP {response.status_code}) — attempting session login (DA Evo)"
|
||||
)
|
||||
cookies = self._da_session_login(scheme, hostname, port, username, password)
|
||||
cookies = self._da_session_login(
|
||||
scheme, hostname, port, username, password
|
||||
)
|
||||
if cookies is None:
|
||||
return None
|
||||
continue # retry this page with cookies
|
||||
@@ -265,7 +261,10 @@ class ReconciliationWorker:
|
||||
total_pages = int(info.get("total_pages", 1))
|
||||
page += 1
|
||||
continue
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[reconciler] JSON decode failed for {hostname}:{port} page {page}: {e}\nRaw response: {response.text[:500]}"
|
||||
)
|
||||
# Fallback to legacy parser
|
||||
domains = self._parse_da_domain_list(response.text)
|
||||
all_domains.update(domains)
|
||||
@@ -298,9 +297,7 @@ class ReconciliationWorker:
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[reconciler] Unexpected error fetching from {hostname}: {e}"
|
||||
)
|
||||
logger.error(f"[reconciler] Unexpected error fetching from {hostname}: {e}")
|
||||
return None
|
||||
|
||||
def _da_session_login(
|
||||
@@ -329,12 +326,12 @@ class ReconciliationWorker:
|
||||
f"check username/password."
|
||||
)
|
||||
return None
|
||||
logger.debug(f"[reconciler] {hostname}:{port} session login successful (DA Evo)")
|
||||
logger.debug(
|
||||
f"[reconciler] {hostname}:{port} session login successful (DA Evo)"
|
||||
)
|
||||
return response.cookies
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[reconciler] {hostname}:{port} session login failed: {e}"
|
||||
)
|
||||
logger.error(f"[reconciler] {hostname}:{port} session login failed: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -353,24 +350,44 @@ class ReconciliationWorker:
|
||||
domains = params.get("list[]", [])
|
||||
return {d.strip().lower() for d in domains if d.strip()}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import sys
|
||||
from queue import Queue
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test DirectAdmin domain fetcher (JSON/paging)")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test DirectAdmin domain fetcher (JSON/paging)"
|
||||
)
|
||||
parser.add_argument("--hostname", required=True, help="DirectAdmin server hostname")
|
||||
parser.add_argument("--port", type=int, default=2222, help="DirectAdmin port (default: 2222)")
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=2222, help="DirectAdmin port (default: 2222)"
|
||||
)
|
||||
parser.add_argument("--username", required=True, help="DirectAdmin admin username")
|
||||
parser.add_argument("--password", required=True, help="DirectAdmin admin password")
|
||||
parser.add_argument("--ssl", action="store_true", help="Use HTTPS (default: True)")
|
||||
parser.add_argument("--no-ssl", dest="ssl", action="store_false", help="Use HTTP (not recommended)")
|
||||
parser.add_argument(
|
||||
"--no-ssl", dest="ssl", action="store_false", help="Use HTTP (not recommended)"
|
||||
)
|
||||
parser.set_defaults(ssl=True)
|
||||
parser.add_argument("--verify-ssl", action="store_true", help="Verify SSL certs (default: True)")
|
||||
parser.add_argument("--no-verify-ssl", dest="verify_ssl", action="store_false", help="Don't verify SSL certs")
|
||||
parser.add_argument(
|
||||
"--verify-ssl", action="store_true", help="Verify SSL certs (default: True)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-verify-ssl",
|
||||
dest="verify_ssl",
|
||||
action="store_false",
|
||||
help="Don't verify SSL certs",
|
||||
)
|
||||
parser.set_defaults(verify_ssl=True)
|
||||
parser.add_argument("--ipp", type=int, default=1000, help="Items per page (default: 1000)")
|
||||
parser.add_argument("--print-json", action="store_true", help="Print raw JSON response for first page")
|
||||
parser.add_argument(
|
||||
"--ipp", type=int, default=1000, help="Items per page (default: 1000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-json",
|
||||
action="store_true",
|
||||
help="Print raw JSON response for first page",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -391,7 +408,9 @@ if __name__ == "__main__":
|
||||
q = Queue()
|
||||
worker = ReconciliationWorker(q, config)
|
||||
server = config["directadmin_servers"][0]
|
||||
print(f"Fetching domains from {server['hostname']}:{server['port']} (ipp={args.ipp})...")
|
||||
print(
|
||||
f"Fetching domains from {server['hostname']}:{server['port']} (ipp={args.ipp})..."
|
||||
)
|
||||
# Directly call the fetch method for testing
|
||||
domains = worker._fetch_da_domains(
|
||||
server["hostname"],
|
||||
@@ -399,7 +418,7 @@ if __name__ == "__main__":
|
||||
server.get("username"),
|
||||
server.get("password"),
|
||||
server.get("ssl", True),
|
||||
ipp=args.ipp
|
||||
ipp=args.ipp,
|
||||
)
|
||||
if domains is None:
|
||||
print("Failed to fetch domains.", file=sys.stderr)
|
||||
|
||||
@@ -59,9 +59,7 @@ def count_zone_records(zone_data: str, domain_name: str) -> int:
|
||||
if rdata.rdclass == IN:
|
||||
count += 1
|
||||
|
||||
logger.debug(
|
||||
f"Source zone {domain_name} contains {count} records"
|
||||
)
|
||||
logger.debug(f"Source zone {domain_name} contains {count} records")
|
||||
return count
|
||||
|
||||
except DNSException as e:
|
||||
|
||||
@@ -10,6 +10,8 @@ from typing import Any, Dict
|
||||
def load_config() -> Vyper:
|
||||
# Initialize Vyper
|
||||
v.set_config_name("app") # Looks for app.yaml/app.yml
|
||||
# Bundled config colocated with this module (always present in the package)
|
||||
v.add_config_path(str(Path(__file__).parent))
|
||||
v.add_config_path(".") # Search in current directory
|
||||
v.add_config_path("./config")
|
||||
v.set_env_prefix("DADNS")
|
||||
@@ -54,11 +56,15 @@ def load_config() -> Vyper:
|
||||
|
||||
# Reconciliation poller defaults
|
||||
v.set_default("reconciliation.enabled", False)
|
||||
v.set_default("reconciliation.dry_run", False)
|
||||
v.set_default("reconciliation.interval_minutes", 60)
|
||||
v.set_default("reconciliation.verify_ssl", True)
|
||||
|
||||
# Read configuration
|
||||
if not v.read_in_config():
|
||||
try:
|
||||
if not v.read_in_config():
|
||||
logger.warning("No config file found, using defaults")
|
||||
except Exception:
|
||||
logger.warning("No config file found, using defaults")
|
||||
|
||||
return v
|
||||
|
||||
@@ -12,8 +12,10 @@ app:
|
||||
# If a DA server is unreachable, that server is skipped entirely.
|
||||
#reconciliation:
|
||||
# enabled: true
|
||||
# dry_run: true # log orphans but do NOT queue deletes — safe first-run mode
|
||||
# interval_minutes: 60
|
||||
# verify_ssl: true # set false for self-signed DA certs
|
||||
# ipp: 1000 # items per page when polling DA (default 1000)
|
||||
# directadmin_servers:
|
||||
# - hostname: da1.example.com
|
||||
# port: 2222
|
||||
|
||||
@@ -50,7 +50,9 @@ def main():
|
||||
|
||||
# Configure CherryPy
|
||||
user_password_dict = {
|
||||
config.get_string("app.auth_username"): config.get_string("app.auth_password")
|
||||
config.get_string("app.auth_username"): config.get_string(
|
||||
"app.auth_password"
|
||||
)
|
||||
}
|
||||
check_password = cherrypy.lib.auth_basic.checkpassword_dict(user_password_dict)
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ from directdnsonly.app.reconciler import ReconciliationWorker
|
||||
|
||||
|
||||
class WorkerManager:
|
||||
def __init__(self, queue_path: str, backend_registry, reconciliation_config: dict = None):
|
||||
def __init__(
|
||||
self, queue_path: str, backend_registry, reconciliation_config: dict = None
|
||||
):
|
||||
self.queue_path = queue_path
|
||||
self.backend_registry = backend_registry
|
||||
self._running = False
|
||||
@@ -86,9 +88,7 @@ class WorkerManager:
|
||||
f"{len(backends)} backends concurrently: "
|
||||
f"{', '.join(backends.keys())}"
|
||||
)
|
||||
self._process_backends_parallel(
|
||||
backends, item, session
|
||||
)
|
||||
self._process_backends_parallel(backends, item, session)
|
||||
else:
|
||||
# Single backend, no need for thread overhead
|
||||
for backend_name, backend in backends.items():
|
||||
@@ -126,9 +126,7 @@ class WorkerManager:
|
||||
try:
|
||||
logger.debug(f"Using backend: {backend_name}")
|
||||
if backend.write_zone(item["domain"], item["zone_file"]):
|
||||
logger.debug(
|
||||
f"Successfully updated {item['domain']} in {backend_name}"
|
||||
)
|
||||
logger.debug(f"Successfully updated {item['domain']} in {backend_name}")
|
||||
if backend.get_name() == "bind":
|
||||
# Need to update the named.conf
|
||||
backend.update_named_conf(
|
||||
@@ -144,9 +142,7 @@ class WorkerManager:
|
||||
backend_name, backend, item["domain"], item["zone_file"]
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to update {item['domain']} in {backend_name}"
|
||||
)
|
||||
logger.error(f"Failed to update {item['domain']} in {backend_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {backend_name}: {str(e)}")
|
||||
|
||||
@@ -165,9 +161,7 @@ class WorkerManager:
|
||||
|
||||
record = session.query(Domain).filter_by(domain=domain).first()
|
||||
if not record:
|
||||
logger.warning(
|
||||
f"Domain {domain} not found in DB — skipping delete"
|
||||
)
|
||||
logger.warning(f"Domain {domain} not found in DB — skipping delete")
|
||||
self.delete_queue.task_done()
|
||||
continue
|
||||
|
||||
@@ -184,29 +178,83 @@ class WorkerManager:
|
||||
f"skipping ownership check, proceeding with delete"
|
||||
)
|
||||
|
||||
session.delete(record)
|
||||
session.commit()
|
||||
logger.info(f"Removed {domain} from database")
|
||||
|
||||
remaining_domains = [d.domain for d in session.query(Domain).all()]
|
||||
|
||||
backends = self.backend_registry.get_available_backends()
|
||||
remaining_domains = [d.domain for d in session.query(Domain).all()]
|
||||
delete_success = True
|
||||
if not backends:
|
||||
logger.warning(
|
||||
f"No active backends — {domain} removed from DB only"
|
||||
f"No active backends — {domain} will be removed from DB only"
|
||||
)
|
||||
elif len(backends) > 1:
|
||||
self._process_backends_delete_parallel(
|
||||
backends, domain, remaining_domains
|
||||
)
|
||||
else:
|
||||
for backend_name, backend in backends.items():
|
||||
self._delete_single_backend(
|
||||
backend_name, backend, domain, remaining_domains
|
||||
)
|
||||
# Parallel delete, track failures
|
||||
results = []
|
||||
|
||||
self.delete_queue.task_done()
|
||||
logger.success(f"Delete completed for {domain}")
|
||||
def delete_backend_wrapper(
|
||||
backend_name, backend, domain, remaining_domains
|
||||
):
|
||||
try:
|
||||
return backend.delete_zone(domain)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deleting {domain} from {backend_name}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(backends)) as executor:
|
||||
futures = {
|
||||
executor.submit(
|
||||
delete_backend_wrapper,
|
||||
backend_name,
|
||||
backend,
|
||||
domain,
|
||||
remaining_domains,
|
||||
): backend_name
|
||||
for backend_name, backend in backends.items()
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
backend_name = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
if not result:
|
||||
logger.error(
|
||||
f"Failed to delete {domain} from {backend_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unhandled error deleting from {backend_name}: {e}"
|
||||
)
|
||||
results.append(False)
|
||||
delete_success = all(results)
|
||||
else:
|
||||
# Single backend
|
||||
for backend_name, backend in backends.items():
|
||||
try:
|
||||
result = backend.delete_zone(domain)
|
||||
if not result:
|
||||
logger.error(
|
||||
f"Failed to delete {domain} from {backend_name}"
|
||||
)
|
||||
delete_success = False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deleting {domain} from {backend_name}: {e}"
|
||||
)
|
||||
delete_success = False
|
||||
|
||||
if delete_success:
|
||||
session.delete(record)
|
||||
session.commit()
|
||||
logger.info(f"Removed {domain} from database")
|
||||
self.delete_queue.task_done()
|
||||
logger.success(f"Delete completed for {domain}")
|
||||
else:
|
||||
logger.error(
|
||||
f"Delete failed for {domain} on one or more backends — DB record retained"
|
||||
)
|
||||
self.delete_queue.task_done()
|
||||
|
||||
except Empty:
|
||||
continue
|
||||
@@ -239,7 +287,10 @@ class WorkerManager:
|
||||
futures = {
|
||||
executor.submit(
|
||||
self._delete_single_backend,
|
||||
backend_name, backend, domain, remaining_domains
|
||||
backend_name,
|
||||
backend,
|
||||
domain,
|
||||
remaining_domains,
|
||||
): backend_name
|
||||
for backend_name, backend in backends.items()
|
||||
}
|
||||
@@ -248,9 +299,7 @@ class WorkerManager:
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unhandled error deleting from {backend_name}: {e}"
|
||||
)
|
||||
logger.error(f"Unhandled error deleting from {backend_name}: {e}")
|
||||
elapsed = (time.monotonic() - start_time) * 1000
|
||||
logger.debug(
|
||||
f"Parallel delete of {domain} across "
|
||||
@@ -261,13 +310,11 @@ class WorkerManager:
|
||||
"""Process zone updates across multiple backends in parallel"""
|
||||
start_time = time.monotonic()
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=len(backends),
|
||||
thread_name_prefix="backend"
|
||||
max_workers=len(backends), thread_name_prefix="backend"
|
||||
) as executor:
|
||||
futures = {
|
||||
executor.submit(
|
||||
self._process_single_backend,
|
||||
backend_name, backend, item, session
|
||||
self._process_single_backend, backend_name, backend, item, session
|
||||
): backend_name
|
||||
for backend_name, backend in backends.items()
|
||||
}
|
||||
@@ -286,9 +333,7 @@ class WorkerManager:
|
||||
f"{len(backends)} backends completed in {elapsed:.0f}ms"
|
||||
)
|
||||
|
||||
def _verify_backend_record_count(
|
||||
self, backend_name, backend, zone_name, zone_data
|
||||
):
|
||||
def _verify_backend_record_count(self, backend_name, backend, zone_name, zone_data):
|
||||
"""Verify and reconcile the backend record count against the
|
||||
authoritative BIND zone from DirectAdmin.
|
||||
|
||||
@@ -313,9 +358,7 @@ class WorkerManager:
|
||||
)
|
||||
return
|
||||
|
||||
matches, actual = backend.verify_zone_record_count(
|
||||
zone_name, expected
|
||||
)
|
||||
matches, actual = backend.verify_zone_record_count(zone_name, expected)
|
||||
|
||||
if matches:
|
||||
return # All good
|
||||
@@ -326,9 +369,7 @@ class WorkerManager:
|
||||
f"record(s) for {zone_name} — reconciling against "
|
||||
f"DirectAdmin source zone"
|
||||
)
|
||||
success, removed = backend.reconcile_zone_records(
|
||||
zone_name, zone_data
|
||||
)
|
||||
success, removed = backend.reconcile_zone_records(zone_name, zone_data)
|
||||
if success and removed > 0:
|
||||
# Verify again after reconciliation
|
||||
matches, new_count = backend.verify_zone_record_count(
|
||||
@@ -406,6 +447,9 @@ class WorkerManager:
|
||||
"save_queue_size": self.save_queue.qsize(),
|
||||
"delete_queue_size": self.delete_queue.qsize(),
|
||||
"save_worker_alive": self._save_thread and self._save_thread.is_alive(),
|
||||
"delete_worker_alive": self._delete_thread and self._delete_thread.is_alive(),
|
||||
"reconciler_alive": self._reconciler.is_alive if self._reconciler else False,
|
||||
"delete_worker_alive": self._delete_thread
|
||||
and self._delete_thread.is_alive(),
|
||||
"reconciler_alive": (
|
||||
self._reconciler.is_alive if self._reconciler else False
|
||||
),
|
||||
}
|
||||
|
||||
109
justfile
109
justfile
@@ -1,17 +1,98 @@
|
||||
#!/usr/bin/env just --justfile
|
||||
# directdnsonly — developer task runner
|
||||
# Requires: just, pyenv, poetry
|
||||
|
||||
APP_NAME := "directdnsonly"
|
||||
|
||||
# Ensure pyenv shims and common install locations are on PATH so that `python`
|
||||
# resolves via pyenv (.python-version) and `poetry` is found without a full
|
||||
# shell init in every recipe.
|
||||
export PATH := env_var("HOME") + "/.pyenv/shims:" + env_var("HOME") + "/.pyenv/bin:" + env_var("HOME") + "/.local/bin:" + env_var("PATH")
|
||||
|
||||
# List available recipes (default)
|
||||
default:
|
||||
@just --list
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Install all dependencies (including dev group)
|
||||
install:
|
||||
poetry install
|
||||
|
||||
# Install only production dependencies
|
||||
install-prod:
|
||||
poetry install --only main
|
||||
|
||||
# Show the Python interpreter that will be used
|
||||
which-python:
|
||||
@poetry run python --version
|
||||
@poetry run python -c "import sys; print(sys.executable)"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Testing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Run the full test suite
|
||||
test:
|
||||
poetry run pytest tests/ -v
|
||||
|
||||
# Run tests with terminal coverage report
|
||||
coverage:
|
||||
poetry run pytest tests/ -v --cov=directdnsonly --cov-report=term-missing
|
||||
|
||||
# Run tests with HTML coverage report (opens in browser)
|
||||
coverage-html:
|
||||
poetry run pytest tests/ --cov=directdnsonly --cov-report=html
|
||||
@echo "Coverage report: htmlcov/index.html"
|
||||
|
||||
# Run a single test file or pattern, e.g. just test-one test_reconciler
|
||||
test-one target:
|
||||
poetry run pytest tests/ -v -k "{{target}}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code quality
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Format all source and test files with black
|
||||
fmt:
|
||||
poetry run black directdnsonly/ tests/
|
||||
|
||||
# Check formatting without making changes (CI-safe)
|
||||
fmt-check:
|
||||
poetry run black --check directdnsonly/ tests/
|
||||
|
||||
# CI gate — run fmt-check then test, fail fast
|
||||
ci: fmt-check test
|
||||
|
||||
# Start the application
|
||||
run:
|
||||
poetry run python -m directdnsonly
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Build a standalone binary with PyInstaller
|
||||
build:
|
||||
cd src && \
|
||||
pyinstaller \
|
||||
-p . \
|
||||
--hidden-import=json \
|
||||
--hidden-import=pyopenssl \
|
||||
--hidden-import=pymysql \
|
||||
--hidden-import=jaraco \
|
||||
--hidden-import=cheroot \
|
||||
--hidden-import=cheroot.ssl.pyopenssl \
|
||||
--hidden-import=cheroot.ssl.builtin \
|
||||
--hidden-import=lib \
|
||||
--hidden-import=os \
|
||||
--hidden-import=builtins \
|
||||
--noconfirm --onefile {{APP_NAME}}.py
|
||||
poetry run pyinstaller \
|
||||
--hidden-import=json \
|
||||
--hidden-import=pymysql \
|
||||
--hidden-import=cheroot \
|
||||
--hidden-import=cheroot.ssl.pyopenssl \
|
||||
--hidden-import=cheroot.ssl.builtin \
|
||||
--noconfirm --onefile \
|
||||
--name=directdnsonly \
|
||||
directdnsonly/main.py
|
||||
rm -f *.spec
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clean
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Remove build artefacts, caches, and compiled bytecode
|
||||
clean:
|
||||
rm -rf dist/ build/*.spec .coverage htmlcov/ .pytest_cache/
|
||||
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -name "*.pyc" -delete 2>/dev/null || true
|
||||
|
||||
40
tests/conftest.py
Normal file
40
tests/conftest.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Shared test fixtures for directdnsonly test suite."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from directdnsonly.app.db import Base
|
||||
from directdnsonly.app.db.models import (
|
||||
Domain,
|
||||
Key,
|
||||
) # noqa: F401 — registers models with Base
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
yield eng
|
||||
eng.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(engine):
|
||||
session = sessionmaker(bind=engine)()
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_connect(db_session, monkeypatch):
|
||||
"""Patch connect() at every call-site, returning the shared test session.
|
||||
|
||||
Modules that import connect() directly (e.g. utils, reconciler) are
|
||||
patched at their local name so the in-memory SQLite session is used
|
||||
instead of trying to read from vyper config.
|
||||
"""
|
||||
_factory = lambda: db_session # noqa: E731
|
||||
monkeypatch.setattr("directdnsonly.app.utils.connect", _factory)
|
||||
monkeypatch.setattr("directdnsonly.app.reconciler.connect", _factory)
|
||||
return db_session
|
||||
219
tests/test_admin_api.py
Normal file
219
tests/test_admin_api.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Tests for directdnsonly.app.api.admin — DNSAdminAPI handler methods."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import cherrypy
|
||||
|
||||
from directdnsonly.app.api.admin import DNSAdminAPI
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def save_queue():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def delete_queue():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api(save_queue, delete_queue):
|
||||
return DNSAdminAPI(save_queue, delete_queue, backend_registry=MagicMock())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CMD_API_LOGIN_TEST
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_login_test_returns_success(api):
|
||||
result = api.CMD_API_LOGIN_TEST()
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["0"]
|
||||
assert parsed["text"] == ["Login OK"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_exists — GET action=exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_handle_exists_missing_domain_returns_error(api):
|
||||
with patch.object(cherrypy, "response", MagicMock()):
|
||||
result = api._handle_exists({"action": "exists"})
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["1"]
|
||||
|
||||
|
||||
def test_handle_exists_unsupported_action_returns_error(api):
|
||||
with patch.object(cherrypy, "response", MagicMock()):
|
||||
result = api._handle_exists({"action": "rawsave"})
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["1"]
|
||||
|
||||
|
||||
def test_handle_exists_domain_not_found(api):
|
||||
with (
|
||||
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False),
|
||||
patch(
|
||||
"directdnsonly.app.api.admin.check_parent_domain_owner", return_value=False
|
||||
),
|
||||
):
|
||||
result = api._handle_exists({"action": "exists", "domain": "example.com"})
|
||||
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["0"]
|
||||
assert parsed["exists"] == ["0"]
|
||||
|
||||
|
||||
def test_handle_exists_domain_found(api):
|
||||
record = MagicMock()
|
||||
record.hostname = "da1.example.com"
|
||||
|
||||
with (
|
||||
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True),
|
||||
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record),
|
||||
):
|
||||
result = api._handle_exists({"action": "exists", "domain": "example.com"})
|
||||
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["0"]
|
||||
assert parsed["exists"] == ["1"]
|
||||
assert "da1.example.com" in parsed["details"][0]
|
||||
|
||||
|
||||
def test_handle_exists_parent_found(api):
|
||||
parent = MagicMock()
|
||||
parent.hostname = "da2.example.com"
|
||||
|
||||
with (
|
||||
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=False),
|
||||
patch(
|
||||
"directdnsonly.app.api.admin.check_parent_domain_owner", return_value=True
|
||||
),
|
||||
patch(
|
||||
"directdnsonly.app.api.admin.get_parent_domain_record", return_value=parent
|
||||
),
|
||||
):
|
||||
result = api._handle_exists(
|
||||
{
|
||||
"action": "exists",
|
||||
"domain": "sub.example.com",
|
||||
"check_for_parent_domain": "1",
|
||||
}
|
||||
)
|
||||
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["0"]
|
||||
assert parsed["exists"] == ["2"]
|
||||
assert "da2.example.com" in parsed["details"][0]
|
||||
|
||||
|
||||
def test_handle_exists_no_parent_check_when_flag_absent(api):
|
||||
"""check_parent_domain_owner should not be called if flag not set."""
|
||||
record = MagicMock()
|
||||
record.hostname = "da1.example.com"
|
||||
|
||||
with (
|
||||
patch("directdnsonly.app.api.admin.check_zone_exists", return_value=True),
|
||||
patch("directdnsonly.app.api.admin.check_parent_domain_owner") as mock_parent,
|
||||
patch("directdnsonly.app.api.admin.get_domain_record", return_value=record),
|
||||
):
|
||||
api._handle_exists({"action": "exists", "domain": "example.com"})
|
||||
|
||||
mock_parent.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_rawsave
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_ZONE = "$ORIGIN example.com.\n$TTL 300\nexample.com. 300 IN A 1.2.3.4\n"
|
||||
|
||||
|
||||
def test_rawsave_enqueues_item(api, save_queue):
|
||||
with (
|
||||
patch(
|
||||
"directdnsonly.app.api.admin.validate_and_normalize_zone",
|
||||
return_value=SAMPLE_ZONE,
|
||||
),
|
||||
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))),
|
||||
):
|
||||
result = api._handle_rawsave(
|
||||
"example.com",
|
||||
{
|
||||
"zone_file": SAMPLE_ZONE,
|
||||
"hostname": "da1.example.com",
|
||||
"username": "admin",
|
||||
},
|
||||
)
|
||||
|
||||
save_queue.put.assert_called_once()
|
||||
item = save_queue.put.call_args[0][0]
|
||||
assert item["domain"] == "example.com"
|
||||
assert item["hostname"] == "da1.example.com"
|
||||
assert item["username"] == "admin"
|
||||
assert item["client_ip"] == "127.0.0.1"
|
||||
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["0"]
|
||||
|
||||
|
||||
def test_rawsave_missing_zone_file_raises(api):
|
||||
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
|
||||
with pytest.raises(ValueError, match="Missing zone file"):
|
||||
api._handle_rawsave("example.com", {})
|
||||
|
||||
|
||||
def test_rawsave_invalid_zone_raises(api):
|
||||
with (
|
||||
patch(
|
||||
"directdnsonly.app.api.admin.validate_and_normalize_zone",
|
||||
side_effect=ValueError("Invalid zone data: bad record"),
|
||||
),
|
||||
patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid zone data"):
|
||||
api._handle_rawsave("example.com", {"zone_file": "garbage"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_delete_enqueues_item(api, delete_queue):
|
||||
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="10.0.0.1"))):
|
||||
result = api._handle_delete(
|
||||
"example.com",
|
||||
{
|
||||
"hostname": "da1.example.com",
|
||||
"username": "admin",
|
||||
},
|
||||
)
|
||||
|
||||
delete_queue.put.assert_called_once()
|
||||
item = delete_queue.put.call_args[0][0]
|
||||
assert item["domain"] == "example.com"
|
||||
assert item["hostname"] == "da1.example.com"
|
||||
assert item["client_ip"] == "10.0.0.1"
|
||||
|
||||
parsed = parse_qs(result)
|
||||
assert parsed["error"] == ["0"]
|
||||
|
||||
|
||||
def test_delete_missing_params_uses_empty_strings(api, delete_queue):
|
||||
with patch.object(cherrypy, "request", MagicMock(remote=MagicMock(ip="127.0.0.1"))):
|
||||
api._handle_delete("example.com", {})
|
||||
|
||||
item = delete_queue.put.call_args[0][0]
|
||||
assert item["hostname"] == ""
|
||||
assert item["username"] == ""
|
||||
@@ -1,47 +1,167 @@
|
||||
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
from directdnsonly.app.backends.coredns_mysql import CoreDNSMySQLBackend, CoreDNSRecord
|
||||
from loguru import logger
|
||||
from directdnsonly.app.backends.coredns_mysql import (
|
||||
Base,
|
||||
CoreDNSMySQLBackend,
|
||||
Record,
|
||||
Zone,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture — in-memory SQLite backend (bypasses real MySQL connection)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mysql_backend(tmp_path):
|
||||
# Setup in-memory SQLite for testing (replace with test MySQL in CI)
|
||||
def mysql_backend():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
CoreDNSRecord.metadata.create_all(engine)
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
class TestBackend(CoreDNSMySQLBackend):
|
||||
class _TestBackend(CoreDNSMySQLBackend):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Manually initialise without triggering the MySQL create_engine call
|
||||
self.config = {}
|
||||
self.instance_name = "test"
|
||||
self.engine = engine
|
||||
self.Session = scoped_session(sessionmaker(bind=engine))
|
||||
|
||||
yield TestBackend()
|
||||
yield _TestBackend()
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_zone_operations(mysql_backend):
|
||||
zone_data = """
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_zone / zone_exists / delete_zone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
ZONE_DATA = """\
|
||||
$ORIGIN example.com.
|
||||
$TTL 300
|
||||
example.com. 300 IN SOA ns.example.com. admin.example.com. (2023 3600 1800 604800 86400)
|
||||
example.com. 300 IN A 192.0.2.1
|
||||
"""
|
||||
# Test zone creation
|
||||
assert mysql_backend.write_zone("example.com", zone_data)
|
||||
|
||||
|
||||
def test_write_zone_creates_zone(mysql_backend):
|
||||
assert mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
|
||||
|
||||
def test_zone_exists_after_write(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
assert mysql_backend.zone_exists("example.com")
|
||||
|
||||
# Test record update
|
||||
updated_zone = """
|
||||
|
||||
def test_zone_does_not_exist_before_write(mysql_backend):
|
||||
assert not mysql_backend.zone_exists("missing.com")
|
||||
|
||||
|
||||
def test_write_zone_idempotent(mysql_backend):
|
||||
assert mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
assert mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
|
||||
|
||||
def test_write_zone_updates_records(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
|
||||
updated = """\
|
||||
$ORIGIN example.com.
|
||||
$TTL 300
|
||||
example.com. 3600 IN A 192.0.2.1
|
||||
example.com. 300 IN AAAA 2001:db8::1
|
||||
"""
|
||||
assert mysql_backend.write_zone("example.com", updated_zone)
|
||||
assert mysql_backend.write_zone("example.com", updated)
|
||||
|
||||
# Test record removal
|
||||
reduced_zone = "example.com. 300 IN A 192.0.2.1"
|
||||
assert mysql_backend.write_zone("example.com", reduced_zone)
|
||||
|
||||
# Test zone deletion
|
||||
def test_write_zone_removes_stale_records(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
|
||||
reduced = "example.com. 300 IN A 192.0.2.1"
|
||||
mysql_backend.write_zone("example.com", reduced)
|
||||
|
||||
session = mysql_backend.Session()
|
||||
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
|
||||
records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all()
|
||||
assert records == []
|
||||
session.close()
|
||||
|
||||
|
||||
def test_delete_zone_removes_zone_and_records(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
assert mysql_backend.delete_zone("example.com")
|
||||
assert not mysql_backend.zone_exists("example.com")
|
||||
|
||||
|
||||
def test_delete_nonexistent_zone_returns_false(mysql_backend):
|
||||
assert not mysql_backend.delete_zone("ghost.com")
|
||||
|
||||
|
||||
def test_reload_zone_returns_true(mysql_backend):
|
||||
assert mysql_backend.reload_zone("example.com")
|
||||
assert mysql_backend.reload_zone()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# verify_zone_record_count
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_verify_zone_record_count_match(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
# SOA + A = 2 records total
|
||||
matches, count = mysql_backend.verify_zone_record_count("example.com", 2)
|
||||
assert matches
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_verify_zone_record_count_mismatch(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
matches, count = mysql_backend.verify_zone_record_count("example.com", 99)
|
||||
assert not matches
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_verify_zone_record_count_missing_zone(mysql_backend):
|
||||
matches, count = mysql_backend.verify_zone_record_count("ghost.com", 0)
|
||||
assert not matches
|
||||
assert count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reconcile_zone_records
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reconcile_removes_extra_records(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
|
||||
# Inject a phantom record directly into the DB
|
||||
session = mysql_backend.Session()
|
||||
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
|
||||
session.add(
|
||||
Record(
|
||||
zone_id=zone.id,
|
||||
hostname="phantom",
|
||||
type="A",
|
||||
data="10.0.0.99",
|
||||
ttl=300,
|
||||
online=True,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA)
|
||||
assert success
|
||||
assert removed == 1
|
||||
|
||||
|
||||
def test_reconcile_no_changes_when_zone_matches(mysql_backend):
|
||||
mysql_backend.write_zone("example.com", ZONE_DATA)
|
||||
success, removed = mysql_backend.reconcile_zone_records("example.com", ZONE_DATA)
|
||||
assert success
|
||||
assert removed == 0
|
||||
|
||||
299
tests/test_reconciler.py
Normal file
299
tests/test_reconciler.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""Tests for directdnsonly.app.reconciler — ReconciliationWorker."""
|
||||
|
||||
import pytest
|
||||
import requests.exceptions
|
||||
from queue import Queue
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from directdnsonly.app.reconciler import ReconciliationWorker
|
||||
from directdnsonly.app.db.models import Domain
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SERVER = {
|
||||
"hostname": "da1.example.com",
|
||||
"port": 2222,
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"ssl": True,
|
||||
}
|
||||
|
||||
BASE_CONFIG = {
|
||||
"enabled": True,
|
||||
"dry_run": False,
|
||||
"interval_minutes": 60,
|
||||
"verify_ssl": True,
|
||||
"ipp": 100,
|
||||
"directadmin_servers": [SERVER],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def delete_queue():
|
||||
return Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def worker(delete_queue):
|
||||
return ReconciliationWorker(delete_queue, BASE_CONFIG)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dry_run_worker(delete_queue):
|
||||
cfg = {**BASE_CONFIG, "dry_run": True}
|
||||
return ReconciliationWorker(delete_queue, cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reconcile_all — orphan detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_orphan_queued_when_domain_missing_from_da(worker, delete_queue, patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="orphan.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(worker, "_fetch_da_domains", return_value=set()):
|
||||
worker._reconcile_all()
|
||||
|
||||
assert not delete_queue.empty()
|
||||
item = delete_queue.get_nowait()
|
||||
assert item["domain"] == "orphan.com"
|
||||
assert item["source"] == "reconciler"
|
||||
|
||||
|
||||
def test_orphan_not_queued_in_dry_run(dry_run_worker, delete_queue, patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="orphan.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(dry_run_worker, "_fetch_da_domains", return_value=set()):
|
||||
dry_run_worker._reconcile_all()
|
||||
|
||||
assert delete_queue.empty()
|
||||
|
||||
|
||||
def test_orphan_not_queued_for_unknown_server(worker, delete_queue, patch_connect):
|
||||
"""Domains whose recorded master is NOT in our configured servers are skipped."""
|
||||
patch_connect.add(
|
||||
Domain(domain="other.com", hostname="da99.unknown.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(worker, "_fetch_da_domains", return_value=set()):
|
||||
worker._reconcile_all()
|
||||
|
||||
assert delete_queue.empty()
|
||||
|
||||
|
||||
def test_active_domain_not_queued(worker, delete_queue, patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="good.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(worker, "_fetch_da_domains", return_value={"good.com"}):
|
||||
worker._reconcile_all()
|
||||
|
||||
assert delete_queue.empty()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reconcile_all — hostname backfill and migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_backfill_null_hostname(worker, patch_connect):
|
||||
patch_connect.add(Domain(domain="backfill.com", hostname=None, username="admin"))
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(worker, "_fetch_da_domains", return_value={"backfill.com"}):
|
||||
worker._reconcile_all()
|
||||
|
||||
record = patch_connect.query(Domain).filter_by(domain="backfill.com").first()
|
||||
assert record.hostname == "da1.example.com"
|
||||
|
||||
|
||||
def test_migration_updates_hostname(worker, patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="moved.com", hostname="da-old.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(worker, "_fetch_da_domains", return_value={"moved.com"}):
|
||||
worker._reconcile_all()
|
||||
|
||||
record = patch_connect.query(Domain).filter_by(domain="moved.com").first()
|
||||
assert record.hostname == "da1.example.com"
|
||||
|
||||
|
||||
def test_dry_run_still_backfills(dry_run_worker, patch_connect):
|
||||
"""Backfill is a data-repair operation, applied even in dry-run mode."""
|
||||
patch_connect.add(Domain(domain="fill.com", hostname=None, username="admin"))
|
||||
patch_connect.commit()
|
||||
|
||||
with patch.object(dry_run_worker, "_fetch_da_domains", return_value={"fill.com"}):
|
||||
dry_run_worker._reconcile_all()
|
||||
|
||||
record = patch_connect.query(Domain).filter_by(domain="fill.com").first()
|
||||
assert record.hostname == "da1.example.com"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _fetch_da_domains — HTTP handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_json_response(domains_dict, total_pages=1):
|
||||
"""Return a mock requests.Response with JSON payload matching DA format."""
|
||||
data = {str(i): {"domain": d} for i, d in enumerate(domains_dict)}
|
||||
data["info"] = {"total_pages": total_pages}
|
||||
mock = MagicMock()
|
||||
mock.status_code = 200
|
||||
mock.is_redirect = False
|
||||
mock.headers = {"Content-Type": "application/json"}
|
||||
mock.json.return_value = data
|
||||
mock.raise_for_status = MagicMock()
|
||||
return mock
|
||||
|
||||
|
||||
def test_fetch_returns_domains_from_json(worker):
|
||||
mock_resp = _make_json_response(["example.com", "test.com"])
|
||||
|
||||
with patch("requests.get", return_value=mock_resp):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result == {"example.com", "test.com"}
|
||||
|
||||
|
||||
def test_fetch_paginates(worker):
|
||||
page1 = _make_json_response(["a.com"], total_pages=2)
|
||||
page2 = _make_json_response(["b.com"], total_pages=2)
|
||||
|
||||
with patch("requests.get", side_effect=[page1, page2]):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result == {"a.com", "b.com"}
|
||||
|
||||
|
||||
def test_fetch_redirect_triggers_session_login(worker):
|
||||
redirect_resp = MagicMock()
|
||||
redirect_resp.status_code = 302
|
||||
redirect_resp.is_redirect = True
|
||||
|
||||
with (
|
||||
patch("requests.get", return_value=redirect_resp),
|
||||
patch.object(worker, "_da_session_login", return_value=None),
|
||||
):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_fetch_html_response_returns_none(worker):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.is_redirect = False
|
||||
mock_resp.headers = {"Content-Type": "text/html; charset=utf-8"}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
||||
with patch("requests.get", return_value=mock_resp):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_fetch_connection_error_returns_none(worker):
|
||||
with patch(
|
||||
"requests.get", side_effect=requests.exceptions.ConnectionError("refused")
|
||||
):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_fetch_timeout_returns_none(worker):
|
||||
with patch("requests.get", side_effect=requests.exceptions.Timeout()):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_fetch_ssl_error_returns_none(worker):
|
||||
with patch(
|
||||
"requests.get", side_effect=requests.exceptions.SSLError("cert verify failed")
|
||||
):
|
||||
result = worker._fetch_da_domains(
|
||||
"da1.example.com", 2222, "admin", "secret", True
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_da_domain_list — legacy format fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_standard_querystring():
|
||||
body = "list[]=example.com&list[]=test.com"
|
||||
result = ReconciliationWorker._parse_da_domain_list(body)
|
||||
assert result == {"example.com", "test.com"}
|
||||
|
||||
|
||||
def test_parse_newline_separated():
|
||||
body = "list[]=example.com\nlist[]=test.com"
|
||||
result = ReconciliationWorker._parse_da_domain_list(body)
|
||||
assert result == {"example.com", "test.com"}
|
||||
|
||||
|
||||
def test_parse_empty_body_returns_empty_set():
|
||||
assert ReconciliationWorker._parse_da_domain_list("") == set()
|
||||
|
||||
|
||||
def test_parse_normalises_to_lowercase():
|
||||
result = ReconciliationWorker._parse_da_domain_list("list[]=EXAMPLE.COM")
|
||||
assert "example.com" in result
|
||||
assert "EXAMPLE.COM" not in result
|
||||
|
||||
|
||||
def test_parse_strips_whitespace():
|
||||
result = ReconciliationWorker._parse_da_domain_list("list[]= example.com ")
|
||||
assert "example.com" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Worker lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_disabled_worker_does_not_start(delete_queue):
|
||||
cfg = {**BASE_CONFIG, "enabled": False}
|
||||
w = ReconciliationWorker(delete_queue, cfg)
|
||||
w.start()
|
||||
assert not w.is_alive
|
||||
|
||||
|
||||
def test_no_servers_does_not_start(delete_queue):
|
||||
cfg = {**BASE_CONFIG, "directadmin_servers": []}
|
||||
w = ReconciliationWorker(delete_queue, cfg)
|
||||
w.start()
|
||||
assert not w.is_alive
|
||||
138
tests/test_utils.py
Normal file
138
tests/test_utils.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Tests for directdnsonly.app.utils — zone index helper functions."""
|
||||
|
||||
import pytest
|
||||
|
||||
from directdnsonly.app.db.models import Domain
|
||||
from directdnsonly.app.utils import (
|
||||
check_zone_exists,
|
||||
check_parent_domain_owner,
|
||||
get_domain_record,
|
||||
get_parent_domain_record,
|
||||
put_zone_index,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_zone_exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_zone_exists_not_found(patch_connect):
|
||||
assert check_zone_exists("example.com") is False
|
||||
|
||||
|
||||
def test_check_zone_exists_found(patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="example.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
assert check_zone_exists("example.com") is True
|
||||
|
||||
|
||||
def test_check_zone_exists_does_not_match_partial(patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="example.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
assert check_zone_exists("sub.example.com") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# put_zone_index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_put_zone_index_adds_record(patch_connect):
|
||||
put_zone_index("new.com", "da1.example.com", "admin")
|
||||
|
||||
record = patch_connect.query(Domain).filter_by(domain="new.com").first()
|
||||
assert record is not None
|
||||
assert record.hostname == "da1.example.com"
|
||||
assert record.username == "admin"
|
||||
|
||||
|
||||
def test_put_zone_index_stores_domain_name(patch_connect):
|
||||
put_zone_index("another.nz", "da2.example.com", "user1")
|
||||
|
||||
assert check_zone_exists("another.nz") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_domain_record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_domain_record_returns_none_when_missing(patch_connect):
|
||||
assert get_domain_record("missing.com") is None
|
||||
|
||||
|
||||
def test_get_domain_record_returns_record(patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="found.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
record = get_domain_record("found.com")
|
||||
assert record is not None
|
||||
assert record.domain == "found.com"
|
||||
assert record.hostname == "da1.example.com"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_parent_domain_owner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_parent_domain_owner_not_found(patch_connect):
|
||||
assert check_parent_domain_owner("sub.example.com") is False
|
||||
|
||||
|
||||
def test_check_parent_domain_owner_found(patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="example.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
assert check_parent_domain_owner("sub.example.com") is True
|
||||
|
||||
|
||||
def test_check_parent_domain_owner_single_label_returns_false(patch_connect):
|
||||
# A single-label name like "com" has no parent
|
||||
assert check_parent_domain_owner("com") is False
|
||||
|
||||
|
||||
def test_check_parent_domain_owner_ignores_grandparent(patch_connect):
|
||||
# Only the immediate parent is checked, not grandparents
|
||||
patch_connect.add(
|
||||
Domain(domain="example.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
# deep.sub.example.com's immediate parent is sub.example.com (not in DB)
|
||||
assert check_parent_domain_owner("deep.sub.example.com") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_parent_domain_record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_parent_domain_record_returns_none_when_missing(patch_connect):
|
||||
assert get_parent_domain_record("sub.example.com") is None
|
||||
|
||||
|
||||
def test_get_parent_domain_record_returns_parent(patch_connect):
|
||||
patch_connect.add(
|
||||
Domain(domain="example.com", hostname="da1.example.com", username="admin")
|
||||
)
|
||||
patch_connect.commit()
|
||||
|
||||
parent = get_parent_domain_record("sub.example.com")
|
||||
assert parent is not None
|
||||
assert parent.domain == "example.com"
|
||||
|
||||
|
||||
def test_get_parent_domain_record_single_label_returns_none(patch_connect):
|
||||
assert get_parent_domain_record("com") is None
|
||||
101
tests/test_zone_parser.py
Normal file
101
tests/test_zone_parser.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Tests for directdnsonly.app.utils.zone_parser."""
|
||||
|
||||
import pytest
|
||||
from dns.exception import DNSException
|
||||
|
||||
from directdnsonly.app.utils.zone_parser import (
|
||||
count_zone_records,
|
||||
validate_and_normalize_zone,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MINIMAL_ZONE = "example.com. 300 IN A 1.2.3.4"
|
||||
|
||||
FULL_ZONE = """\
|
||||
$ORIGIN example.com.
|
||||
$TTL 300
|
||||
@ IN SOA ns1.example.com. admin.example.com. 2024010101 3600 900 604800 300
|
||||
@ IN NS ns1.example.com.
|
||||
@ IN A 1.2.3.4
|
||||
www IN A 5.6.7.8
|
||||
mail IN MX 10 mail.example.com.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_and_normalize_zone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_adds_origin_when_missing():
|
||||
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
|
||||
assert "$ORIGIN example.com." in result
|
||||
|
||||
|
||||
def test_validate_adds_ttl_when_missing():
|
||||
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
|
||||
assert "$TTL" in result
|
||||
|
||||
|
||||
def test_validate_does_not_duplicate_origin():
|
||||
zone = "$ORIGIN example.com.\nexample.com. 300 IN A 1.2.3.4"
|
||||
result = validate_and_normalize_zone(zone, "example.com")
|
||||
assert result.count("$ORIGIN") == 1
|
||||
|
||||
|
||||
def test_validate_does_not_duplicate_ttl():
|
||||
zone = "$TTL 300\nexample.com. 300 IN A 1.2.3.4"
|
||||
result = validate_and_normalize_zone(zone, "example.com")
|
||||
assert result.count("$TTL") == 1
|
||||
|
||||
|
||||
def test_validate_appends_dot_to_domain():
|
||||
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
|
||||
assert "$ORIGIN example.com." in result
|
||||
|
||||
|
||||
def test_validate_returns_string():
|
||||
result = validate_and_normalize_zone(MINIMAL_ZONE, "example.com")
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_validate_full_zone_passes():
|
||||
result = validate_and_normalize_zone(FULL_ZONE, "example.com")
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_validate_raises_on_invalid_zone():
|
||||
bad_zone = "this is not a zone file at all !!!"
|
||||
with pytest.raises(ValueError, match="Invalid zone data"):
|
||||
validate_and_normalize_zone(bad_zone, "example.com")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# count_zone_records
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_count_records_simple_zone():
|
||||
zone = "$ORIGIN example.com.\n$TTL 300\n@ IN A 1.2.3.4\n@ IN AAAA ::1\n"
|
||||
count = count_zone_records(zone, "example.com")
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_count_records_soa_included():
|
||||
count = count_zone_records(FULL_ZONE, "example.com")
|
||||
# SOA + NS + A (apex) + A (www) + MX = 5
|
||||
assert count == 5
|
||||
|
||||
|
||||
def test_count_records_returns_negative_on_bad_zone():
|
||||
count = count_zone_records("not a valid zone", "example.com")
|
||||
assert count == -1
|
||||
|
||||
|
||||
def test_count_records_empty_zone():
|
||||
zone = "$ORIGIN example.com.\n$TTL 300\n"
|
||||
count = count_zone_records(zone, "example.com")
|
||||
assert count == 0
|
||||
Reference in New Issue
Block a user