Files
directdnsonly/directdnsonly/app/backends/coredns_mysql.py
Aaron Guise 74c5f4012e style: apply black formatting across codebase 🎨
No logic changes — pure reformatting of line lengths, dict literals,
method-chain line breaks, and trailing newlines to satisfy black's style.
2026-02-18 22:53:09 +13:00

449 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Optional, Dict, Set, Tuple, Any
from urllib.parse import quote_plus

from dns import zone as dns_zone_module
from dns.rdataclass import IN
from loguru import logger
from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session, relationship

from .base import DNSBackend
# Declarative base for the Zone and Record ORM models; its metadata is
# passed to create_all() by CoreDNSMySQLBackend to create missing tables.
Base = declarative_base()
class Zone(Base):
    """ORM model for one row of the ``zones`` table."""

    __tablename__ = "zones"

    # Surrogate key; referenced by records.zone_id.
    id = Column(Integer, primary_key=True)
    # Unique zone name. The backend stores it dot-terminated
    # (e.g. "example.com.") via its dot_fqdn helper.
    zone_name = Column(String(255), unique=True, index=True, nullable=False)
class Record(Base):
    """ORM model for one resource record in the ``records`` table."""

    __tablename__ = "records"

    id = Column(Integer, primary_key=True)
    # Owning zone (foreign key to zones.id).
    zone_id = Column(Integer, ForeignKey("zones.id"), nullable=False)
    # Owner name of the record.
    hostname = Column(String(255), index=True, nullable=False)
    # Record type such as "A", "CNAME" or "SOA". The attribute mirrors the
    # column name; it only shadows the builtin inside this class body.
    type = Column(String(10), nullable=False)
    # Record data in zone-file text form.
    data = Column(Text, nullable=False)
    # TTL in seconds; may be NULL.
    ttl = Column(Integer, nullable=True)
    # Enabled flag. The backend writes True for every record it syncs;
    # presumably the CoreDNS plugin skips rows where this is False -- TODO
    # confirm against the plugin.
    online = Column(Boolean, default=False, nullable=False)
    # Many-to-one link to Zone; also defines the ``Zone.records`` backref.
    zone = relationship("Zone", backref="records")
class CoreDNSMySQLBackend(DNSBackend):
    """DNS backend that mirrors BIND zone files into MySQL tables for CoreDNS.

    Each zone is one row in the ``zones`` table (``Zone``) and each of its
    resource records is one row in the ``records`` table (``Record``). Zone
    names are stored dot-terminated (see ``dot_fqdn``). Methods that touch
    the database open their own session from ``self.Session`` and close it
    in a ``finally`` block.
    """

    def __init__(self, config: Dict[str, Any]):
        """Connect to MySQL and create the backend's tables if missing.

        Args:
            config: Backend settings. Keys read here: ``host`` (default
                ``"localhost"``), ``port`` (default ``3306``), ``database``
                (default ``"coredns"``), ``username``, ``password``, and the
                optional connection-pool settings ``pool_size`` (default
                ``5``) and ``max_overflow`` (default ``10``).
        """
        super().__init__(config)
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 3306)
        self.database = config.get("database", "coredns")
        self.username = config.get("username")
        self.password = config.get("password")

        # Percent-encode the credentials: characters such as "@", ":" or "/"
        # in a password would otherwise be parsed as URL delimiters. Missing
        # credentials are omitted rather than sent as the string "None".
        credentials = ""
        if self.username is not None:
            credentials = quote_plus(str(self.username))
            if self.password is not None:
                credentials += ":" + quote_plus(str(self.password))
            credentials += "@"

        self.engine = create_engine(
            f"mysql+pymysql://{credentials}"
            f"{self.host}:{self.port}/{self.database}",
            pool_pre_ping=True,  # test pooled connections before handing out
            pool_size=config.get("pool_size", 5),
            max_overflow=config.get("max_overflow", 10),
        )
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        # Create the zones/records tables if they do not already exist.
        Base.metadata.create_all(self.engine)
        logger.info(
            f"Initialized CoreDNS MySQL backend '{self.instance_name}' "
            f"for {self.database}@{self.host}:{self.port}"
        )

    @staticmethod
    def dot_fqdn(zone_name):
        """Return ``zone_name`` terminated by a dot, appending one if missing."""
        return f"{zone_name}." if not zone_name.endswith(".") else zone_name

    @classmethod
    def get_name(cls) -> str:
        """Return the short name identifying this backend type."""
        return "coredns_mysql"

    @classmethod
    def is_available(cls) -> bool:
        """Return True if the PyMySQL driver can be imported.

        Logs a warning and returns False when the import fails.
        """
        try:
            import pymysql  # noqa: F401  (imported only to probe availability)
        except ImportError:
            logger.warning("PyMySQL not available - CoreDNS MySQL backend disabled")
            return False
        return True

    def write_zone(self, zone_name: str, zone_data: str) -> bool:
        """Synchronise a zone's database rows with a BIND zone file.

        The zone row is created if it does not exist. Non-SOA records are
        matched on ``(hostname, type, data)``:
        - keys missing from the database are inserted;
        - matching rows have their TTL refreshed;
        - rows whose key is absent from ``zone_data`` are deleted.

        Duplicate rows that share a key, and extra SOA rows, are deleted so
        that one row per record remains. A well-formed (7-field) SOA is
        updated in place or inserted. An existing SOA is removed when the
        source has none. Everything is committed as one transaction.

        Args:
            zone_name: Zone name, with or without the trailing dot.
            zone_data: Raw BIND zone file content (the authoritative source).

        Returns:
            True on success. False if anything raised; the error is logged
            and the transaction rolled back.
        """
        session = self.Session()
        try:
            zone = self._ensure_zone_exists(session, zone_name)
            changes = {"added": 0, "updated": 0, "removed": 0}

            # Index current rows by key, keeping the SOA separate. A key can
            # map to several rows (e.g. after concurrent writes); only one
            # survivor can be tracked, so delete the extras here instead of
            # leaving them orphaned in the table.
            existing_records = {}
            existing_soa = None
            for r in session.query(Record).filter_by(zone_id=zone.id).all():
                if r.type == "SOA":
                    duplicate = existing_soa
                    existing_soa = r
                else:
                    key = (r.hostname, r.type, r.data)
                    duplicate = existing_records.get(key)
                    existing_records[key] = r
                if duplicate is not None:
                    logger.debug(
                        f"Removed duplicate record: {duplicate.hostname} "
                        f"{duplicate.type} {duplicate.data}"
                    )
                    session.delete(duplicate)
                    changes["removed"] += 1

            source_records, source_soa = self._parse_zone_to_record_set(
                zone_name, zone_data
            )

            # SOA: only a well-formed 7-field SOA is written.
            if source_soa:
                soa_name, soa_content, soa_ttl = source_soa
                if len(soa_content.split()) == 7:
                    if existing_soa is not None:
                        # Count an update only when a value actually differs;
                        # otherwise every sync reports "1 updated" and the
                        # "no changes" path is unreachable.
                        if (
                            existing_soa.data != soa_content
                            or existing_soa.ttl != soa_ttl
                            or not existing_soa.online
                        ):
                            existing_soa.data = soa_content
                            existing_soa.ttl = soa_ttl
                            existing_soa.online = True
                            changes["updated"] += 1
                            logger.debug(
                                f"Updated SOA record: {soa_name} SOA {soa_content}"
                            )
                    else:
                        existing_soa = Record(
                            zone_id=zone.id,
                            hostname=soa_name,
                            type="SOA",
                            data=soa_content,
                            ttl=soa_ttl,
                            online=True,
                        )
                        session.add(existing_soa)
                        changes["added"] += 1
                        logger.debug(f"Added SOA record: {soa_name} SOA {soa_content}")

            # Non-SOA records: insert new keys, refresh TTLs on existing ones.
            current_records = set()
            for record_name, record_type, record_content, record_ttl in source_records:
                key = (record_name, record_type, record_content)
                current_records.add(key)
                record = existing_records.get(key)
                if record is not None:
                    if record.ttl != record_ttl:
                        record.ttl = record_ttl
                        record.online = True
                        changes["updated"] += 1
                        logger.debug(
                            f"Updated TTL for record: {record_name} {record_type} {record_content}"
                        )
                    continue
                session.add(
                    Record(
                        zone_id=zone.id,
                        hostname=record_name,
                        type=record_type,
                        data=record_content,
                        ttl=record_ttl,
                        online=True,
                    )
                )
                changes["added"] += 1
                logger.debug(
                    f"Added new record: {record_name} {record_type} {record_content}"
                )

            # Remove records that no longer exist in the source zone.
            for key, record in existing_records.items():
                if key not in current_records:
                    logger.debug(
                        f"Removed record: {record.hostname} {record.type} {record.data}"
                    )
                    session.delete(record)
                    changes["removed"] += 1

            # Remove the SOA if the source zone no longer has one.
            if existing_soa is not None and not source_soa:
                logger.debug(
                    f"Removed SOA record: {existing_soa.hostname} SOA {existing_soa.data}"
                )
                session.delete(existing_soa)
                changes["removed"] += 1

            session.commit()

            total_changes = changes["added"] + changes["updated"] + changes["removed"]
            if total_changes > 0:
                logger.info(
                    f"[{self.instance_name}] Zone {zone_name} updated: "
                    f"{changes['added']} added, {changes['updated']} updated, "
                    f"{changes['removed']} removed"
                )
            else:
                logger.debug(f"[{self.instance_name}] Zone {zone_name}: no changes")
            return True
        except Exception as e:
            logger.error(f"Error writing zone {zone_name}: {e}")
            session.rollback()
            return False
        finally:
            session.close()

    def delete_zone(self, zone_name: str) -> bool:
        """Delete a zone row together with all of its record rows.

        Args:
            zone_name: Zone name, with or without the trailing dot.

        Returns:
            True if the zone existed and was deleted. False if it was not
            found, or if deletion failed (the transaction is rolled back and
            the error logged).
        """
        session = self.Session()
        try:
            zone = (
                session.query(Zone)
                .filter_by(zone_name=self.dot_fqdn(zone_name))
                .first()
            )
            if not zone:
                logger.warning(f"Zone {zone_name} not found for deletion")
                return False
            # Records reference zones.id, so bulk-delete them before the zone.
            count = session.query(Record).filter_by(zone_id=zone.id).delete()
            session.delete(zone)
            session.commit()
            logger.info(f"Deleted zone {zone_name} with {count} records")
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"Zone deletion failed for {zone_name}: {e}")
            return False
        finally:
            session.close()

    def reload_zone(self, zone_name: Optional[str] = None) -> bool:
        """Log a reload request; no action is needed for this backend.

        In coredns_mysql_extend, the core plugin handles reloading
        automatically when database changes are detected, so this only logs
        the request and always returns True.

        Args:
            zone_name: Zone that changed, or None for all zones.
        """
        if zone_name:
            logger.debug(f"CoreDNS reload triggered for zone {zone_name}")
        else:
            logger.debug("CoreDNS reload triggered for all zones")
        return True

    def zone_exists(self, zone_name: str) -> bool:
        """Return True if a zone row exists for ``zone_name``.

        Errors are logged and reported as False.
        """
        session = self.Session()
        try:
            exists = (
                session.query(Zone)
                .filter_by(zone_name=self.dot_fqdn(zone_name))
                .first()
                is not None
            )
            logger.debug(f"Zone existence check for {zone_name}: {exists}")
            return exists
        except Exception as e:
            logger.error(f"Zone existence check failed for {zone_name}: {e}")
            return False
        finally:
            session.close()

    def _ensure_zone_exists(self, session, zone_name: str) -> Zone:
        """Return the Zone row for ``zone_name``, creating it if necessary.

        A newly created row is flushed, not committed, so its ``id`` is
        available to the caller within the same transaction.
        """
        fqdn = self.dot_fqdn(zone_name)
        zone = session.query(Zone).filter_by(zone_name=fqdn).first()
        if not zone:
            logger.debug(f"Creating new zone: {fqdn}")
            zone = Zone(zone_name=fqdn)
            session.add(zone)
            session.flush()  # assigns zone.id without committing
        return zone

    def _normalize_cname_data(self, zone_name: str, record_content: str) -> str:
        """Normalize CNAME record data to ensure consistent FQDN format.

        CNAME targets are stored as fully-qualified domain names, so that
        comparing records from the BIND zone source with the database is
        deterministic. A target starting with "@" becomes the zone FQDN. A
        target without a trailing dot has the zone FQDN appended.

        Args:
            zone_name: The zone name for relative-name expansion
            record_content: The raw CNAME target from the parsed zone

        Returns:
            The normalized CNAME target string
        """
        if record_content.startswith("@"):
            logger.debug("CNAME target starts with '@', replacing with zone FQDN")
            record_content = self.dot_fqdn(zone_name)
        elif not record_content.endswith("."):
            logger.debug(f"CNAME target {record_content} is relative, appending zone")
            record_content = ".".join([record_content, self.dot_fqdn(zone_name)])
        return record_content

    def _parse_zone_to_record_set(
        self, zone_name: str, zone_data: str
    ) -> Tuple[Set[Tuple[str, str, str, int]], Optional[Tuple[str, str, int]]]:
        """Parse a BIND zone file into a set of normalised record keys.

        Only class IN records are kept, and CNAME targets are normalised with
        ``_normalize_cname_data``. If the zone holds several SOA rdatas, the
        last one seen wins.

        NOTE(review): ``from_text`` is called without an explicit origin, so
        the zone text presumably has to supply one (e.g. via $ORIGIN). TODO
        confirm against DirectAdmin output.

        Returns:
            Tuple of:
            - set of (hostname, type, data, ttl) tuples for non-SOA records
            - (hostname, soa_data, ttl) tuple for the SOA record, or None
        """
        dns_zone = dns_zone_module.from_text(zone_data, check_origin=False)
        records: Set[Tuple[str, str, str, int]] = set()
        soa = None
        for name, ttl, rdata in dns_zone.iterate_rdatas():
            if rdata.rdclass != IN:
                continue
            record_name = str(name)
            record_type = rdata.rdtype.name
            record_content = rdata.to_text()
            if record_type == "SOA":
                soa = (record_name, record_content, ttl)
                continue
            if record_type == "CNAME":
                record_content = self._normalize_cname_data(zone_name, record_content)
            records.add((record_name, record_type, record_content, ttl))
        return records, soa

    def verify_zone_record_count(
        self, zone_name: str, expected_count: int
    ) -> Tuple[bool, int]:
        """Verify the record count in this backend matches the expected count
        from the source (DirectAdmin) zone file.

        Every row for the zone is counted, including the SOA row.

        Args:
            zone_name: The zone to verify
            expected_count: The number of records parsed from the source BIND zone

        Returns:
            Tuple of (matches, actual_count). (False, 0) if the zone is not
            found; (False, -1) if an error occurs.
        """
        session = self.Session()
        try:
            zone = (
                session.query(Zone)
                .filter_by(zone_name=self.dot_fqdn(zone_name))
                .first()
            )
            if not zone:
                logger.warning(
                    f"[{self.instance_name}] Zone {zone_name} not found "
                    f"during record count verification"
                )
                return False, 0

            actual_count = session.query(Record).filter_by(zone_id=zone.id).count()
            matches = actual_count == expected_count
            if not matches:
                logger.warning(
                    f"[{self.instance_name}] Record count mismatch for "
                    f"{zone_name}: source zone has {expected_count} records, "
                    f"backend has {actual_count} records "
                    f"(difference: {actual_count - expected_count:+d})"
                )
            else:
                logger.debug(
                    f"[{self.instance_name}] Record count verified for "
                    f"{zone_name}: {actual_count} records match source"
                )
            return matches, actual_count
        except Exception as e:
            logger.error(
                f"[{self.instance_name}] Error verifying record count "
                f"for {zone_name}: {e}"
            )
            return False, -1
        finally:
            session.close()

    def reconcile_zone_records(
        self, zone_name: str, zone_data: str
    ) -> Tuple[bool, int]:
        """Delete backend records that do not belong to the source zone.

        Reconciles the backend against the authoritative BIND zone from
        DirectAdmin. A non-SOA row is deleted in either of two cases:
        - its ``(hostname, type, data)`` key is absent from the source zone;
        - it duplicates a key already kept (the first row per key survives).

        SOA rows are left untouched.

        This is the post-write safety net. ``write_zone`` already removes
        stale records during normal processing; this method catches extras
        that crept in via race conditions, manual edits, or replication
        drift between MySQL nodes. That includes duplicate rows for a record
        that is legitimately present in the source.

        Args:
            zone_name: The zone to reconcile
            zone_data: The raw BIND zone file content (authoritative source)

        Returns:
            Tuple of (success, records_removed). (False, 0) if the zone is
            not found or an error occurs.
        """
        session = self.Session()
        try:
            zone = (
                session.query(Zone)
                .filter_by(zone_name=self.dot_fqdn(zone_name))
                .first()
            )
            if not zone:
                logger.warning(
                    f"[{self.instance_name}] Zone {zone_name} not found "
                    f"during reconciliation"
                )
                return False, 0

            source_records, _ = self._parse_zone_to_record_set(zone_name, zone_data)
            # Lookup keys without TTL, matching write_zone's key format.
            expected_keys: Set[Tuple[str, str, str]] = {
                (hostname, rtype, data) for hostname, rtype, data, _ in source_records
            }

            seen_keys: Set[Tuple[str, str, str]] = set()
            removed = 0
            for record in session.query(Record).filter_by(zone_id=zone.id).all():
                # SOA records are managed separately by write_zone; skip them.
                if record.type == "SOA":
                    continue
                key = (record.hostname, record.type, record.data)
                if key not in expected_keys:
                    reason = "extra"
                elif key in seen_keys:
                    reason = "duplicate"
                else:
                    # First row for a legitimate key: keep it.
                    seen_keys.add(key)
                    continue
                logger.debug(
                    f"[{self.instance_name}] Reconcile: removing {reason} "
                    f"record from {zone_name}: "
                    f"{record.hostname} {record.type} {record.data}"
                )
                session.delete(record)
                removed += 1

            if removed > 0:
                session.commit()
                logger.info(
                    f"[{self.instance_name}] Reconciliation for {zone_name}: "
                    f"removed {removed} extra or duplicate record(s) "
                    f"not matching source zone"
                )
            else:
                logger.debug(
                    f"[{self.instance_name}] Reconciliation for {zone_name}: "
                    f"all records match source zone — no action needed"
                )
            return True, removed
        except Exception as e:
            logger.error(
                f"[{self.instance_name}] Error reconciling records "
                f"for {zone_name}: {e}"
            )
            session.rollback()
            return False, 0
        finally:
            session.close()