# Forked from directdnsonly.
# No logic changes — pure reformatting of line lengths, dict literals,
# method-chain line breaks, and trailing newlines to satisfy black's style.
from typing import Optional, Dict, Set, Tuple, Any

from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session, relationship

from dns import zone as dns_zone_module
from dns.rdataclass import IN

from loguru import logger

from .base import DNSBackend

# Shared declarative base for the ORM models defined in this module.
Base = declarative_base()
class Zone(Base):
    """ORM model for one DNS zone row in the ``zones`` table."""

    __tablename__ = "zones"

    id = Column(Integer, primary_key=True)
    # Zone name as written by this backend — always in dotted FQDN form,
    # since every read/write goes through dot_fqdn().
    zone_name = Column(String(255), nullable=False, index=True, unique=True)
class Record(Base):
    """ORM model for a single DNS record row in the ``records`` table."""

    __tablename__ = "records"

    id = Column(Integer, primary_key=True)
    zone_id = Column(Integer, ForeignKey("zones.id"), nullable=False)
    hostname = Column(String(255), nullable=False, index=True)
    # Record type mnemonic, e.g. "A", "CNAME", "SOA".
    type = Column(String(10), nullable=False)
    # Rdata text as produced by dnspython's to_text().
    data = Column(Text, nullable=False)
    ttl = Column(Integer, nullable=True)
    # Set True whenever this backend writes the row; the flag's meaning on
    # the CoreDNS side is not visible here — presumably marks the record
    # as servable. TODO confirm against the coredns_mysql_extend schema.
    online = Column(Boolean, nullable=False, default=False)

    zone = relationship("Zone", backref="records")
class CoreDNSMySQLBackend(DNSBackend):
    """DNS backend that mirrors zones into the MySQL schema used by the
    coredns_mysql_extend plugin.

    Zones and records are written into the ``zones``/``records`` tables;
    the CoreDNS plugin picks up database changes on its own, so reload
    requests are no-ops here (see ``reload_zone``).
    """

    def __init__(self, config: Dict[str, Any]):
        """Connect to MySQL and ensure the schema exists.

        Args:
            config: Backend configuration; keys ``host``, ``port``,
                ``database``, ``username`` and ``password`` are read
                (the first three have defaults).
        """
        super().__init__(config)
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 3306)
        self.database = config.get("database", "coredns")
        self.username = config.get("username")
        self.password = config.get("password")

        # pool_pre_ping guards against reusing pooled connections that
        # MySQL has already closed server-side.
        self.engine = create_engine(
            f"mysql+pymysql://{self.username}:{self.password}@"
            f"{self.host}:{self.port}/{self.database}",
            pool_pre_ping=True,
            pool_size=5,
            max_overflow=10,
        )
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        # Create the zones/records tables if they are missing.
        Base.metadata.create_all(self.engine)
        logger.info(
            f"Initialized CoreDNS MySQL backend '{self.instance_name}' "
            f"for {self.database}@{self.host}:{self.port}"
        )
@staticmethod
|
||
def dot_fqdn(zone_name):
|
||
return f"{zone_name}." if not zone_name.endswith(".") else zone_name
|
||
|
||
    @classmethod
    def get_name(cls) -> str:
        """Return the registry identifier for this backend type."""
        return "coredns_mysql"
@classmethod
|
||
def is_available(cls) -> bool:
|
||
try:
|
||
import pymysql
|
||
|
||
return True
|
||
except ImportError:
|
||
logger.warning("PyMySQL not available - CoreDNS MySQL backend disabled")
|
||
return False
|
||
|
||
    def write_zone(self, zone_name: str, zone_data: str) -> bool:
        """Synchronise one zone's records from BIND zone text into MySQL.

        Parses *zone_data*, then adds, updates (TTL only — a changed rdata
        shows up as a delete + add because rdata is part of the key) and
        removes rows so the ``records`` table matches the source zone.
        The SOA record is diffed separately from all other types.

        Args:
            zone_name: Zone to write (with or without trailing dot).
            zone_data: Raw BIND zone file content.

        Returns:
            True on successful commit; False on any error (the transaction
            is rolled back).
        """
        session = self.Session()
        try:
            # Ensure zone exists
            zone = self._ensure_zone_exists(session, zone_name)

            # Get existing records for this zone but track SOA records separately
            existing_records = {}
            existing_soa = None
            for r in session.query(Record).filter_by(zone_id=zone.id).all():
                if r.type == "SOA":
                    existing_soa = r
                else:
                    # Keyed by (hostname, type, data) — TTL deliberately
                    # excluded so TTL-only changes are updates, not churn.
                    existing_records[(r.hostname, r.type, r.data)] = r

            # Parse the zone data into a normalised record set
            source_records, source_soa = self._parse_zone_to_record_set(
                zone_name, zone_data
            )

            # Track changes
            current_records = set()
            changes = {"added": 0, "updated": 0, "removed": 0}

            # Handle SOA record
            if source_soa:
                soa_name, soa_content, soa_ttl = source_soa
                soa_parts = soa_content.split()
                # A well-formed SOA rdata has exactly 7 fields (mname, rname,
                # serial, refresh, retry, expire, minimum); anything else is
                # silently skipped.
                if len(soa_parts) == 7:
                    if existing_soa:
                        existing_soa.data = soa_content
                        existing_soa.ttl = soa_ttl
                        existing_soa.online = True
                        changes["updated"] += 1
                        logger.debug(
                            f"Updated SOA record: {soa_name} SOA {soa_content}"
                        )
                    else:
                        existing_soa = Record(
                            zone_id=zone.id,
                            hostname=soa_name,
                            type="SOA",
                            data=soa_content,
                            ttl=soa_ttl,
                            online=True,
                        )
                        session.add(existing_soa)
                        changes["added"] += 1
                        logger.debug(f"Added SOA record: {soa_name} SOA {soa_content}")

            # Process all non-SOA records
            for record_name, record_type, record_content, record_ttl in source_records:
                key = (record_name, record_type, record_content)
                current_records.add(key)

                if key in existing_records:
                    # Update existing record if TTL changed
                    record = existing_records[key]
                    if record.ttl != record_ttl:
                        record.ttl = record_ttl
                        record.online = True
                        changes["updated"] += 1
                        logger.debug(
                            f"Updated TTL for record: {record_name} {record_type} {record_content}"
                        )
                else:
                    # Add new record
                    new_record = Record(
                        zone_id=zone.id,
                        hostname=record_name,
                        type=record_type,
                        data=record_content,
                        ttl=record_ttl,
                        online=True,
                    )
                    session.add(new_record)
                    changes["added"] += 1
                    logger.debug(
                        f"Added new record: {record_name} {record_type} {record_content}"
                    )

            # Remove records that no longer exist in the source zone
            for key, record in existing_records.items():
                if key not in current_records:
                    logger.debug(
                        f"Removed record: {record.hostname} {record.type} {record.data}"
                    )
                    session.delete(record)
                    changes["removed"] += 1

            # Handle SOA removal if needed
            if existing_soa and not source_soa:
                logger.debug(
                    f"Removed SOA record: {existing_soa.hostname} SOA {existing_soa.data}"
                )
                session.delete(existing_soa)
                changes["removed"] += 1

            session.commit()
            total_changes = changes["added"] + changes["updated"] + changes["removed"]
            if total_changes > 0:
                logger.info(
                    f"[{self.instance_name}] Zone {zone_name} updated: "
                    f"{changes['added']} added, {changes['updated']} updated, "
                    f"{changes['removed']} removed"
                )
            else:
                logger.debug(f"[{self.instance_name}] Zone {zone_name}: no changes")
            return True

        except Exception as e:
            logger.error(f"Error writing zone {zone_name}: {e}")
            session.rollback()
            return False
        finally:
            session.close()
def delete_zone(self, zone_name: str) -> bool:
|
||
session = self.Session()
|
||
try:
|
||
# First find the zone
|
||
zone = (
|
||
session.query(Zone)
|
||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
||
.first()
|
||
)
|
||
if not zone:
|
||
logger.warning(f"Zone {zone_name} not found for deletion")
|
||
return False
|
||
|
||
# Delete all records associated with the zone
|
||
count = session.query(Record).filter_by(zone_id=zone.id).delete()
|
||
|
||
# Delete the zone itself
|
||
session.delete(zone)
|
||
session.commit()
|
||
|
||
logger.info(f"Deleted zone {zone_name} with {count} records")
|
||
return True
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"Zone deletion failed for {zone_name}: {e}")
|
||
return False
|
||
finally:
|
||
session.close()
|
||
|
||
def reload_zone(self, zone_name: Optional[str] = None) -> bool:
|
||
# In coredns_mysql_extend, the core plugin handles reloading automatically
|
||
# when database changes are detected, so we just log the request
|
||
if zone_name:
|
||
logger.debug(f"CoreDNS reload triggered for zone {zone_name}")
|
||
else:
|
||
logger.debug("CoreDNS reload triggered for all zones")
|
||
return True
|
||
|
||
def zone_exists(self, zone_name: str) -> bool:
|
||
session = self.Session()
|
||
try:
|
||
exists = (
|
||
session.query(Zone)
|
||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
||
.first()
|
||
is not None
|
||
)
|
||
logger.debug(f"Zone existence check for {zone_name}: {exists}")
|
||
return exists
|
||
except Exception as e:
|
||
logger.error(f"Zone existence check failed for {zone_name}: {e}")
|
||
return False
|
||
finally:
|
||
session.close()
|
||
|
||
def _ensure_zone_exists(self, session, zone_name: str) -> Zone:
|
||
"""Ensure a zone exists in the database, creating it if necessary"""
|
||
zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first()
|
||
if not zone:
|
||
logger.debug(f"Creating new zone: {self.dot_fqdn(zone_name)}")
|
||
zone = Zone(zone_name=self.dot_fqdn(zone_name))
|
||
session.add(zone)
|
||
session.flush() # Get the zone ID
|
||
return zone
|
||
|
||
def _normalize_cname_data(self, zone_name: str, record_content: str) -> str:
|
||
"""Normalize CNAME record data to ensure consistent FQDN format.
|
||
|
||
This ensures CNAME targets are always stored as fully-qualified domain
|
||
names so that record comparison between the BIND zone source and the
|
||
database is deterministic.
|
||
|
||
Args:
|
||
zone_name: The zone name for relative-name expansion
|
||
record_content: The raw CNAME target from the parsed zone
|
||
|
||
Returns:
|
||
The normalized CNAME target string
|
||
"""
|
||
if record_content.startswith("@"):
|
||
logger.debug(f"CNAME target starts with '@', replacing with zone FQDN")
|
||
record_content = self.dot_fqdn(zone_name)
|
||
elif not record_content.endswith("."):
|
||
logger.debug(f"CNAME target {record_content} is relative, appending zone")
|
||
record_content = ".".join([record_content, self.dot_fqdn(zone_name)])
|
||
return record_content
|
||
|
||
    def _parse_zone_to_record_set(
        self, zone_name: str, zone_data: str
    ) -> Tuple[Set[Tuple[str, str, str, int]], Optional[Tuple[str, str, int]]]:
        """Parse a BIND zone file into a set of normalised record keys.

        Only IN-class records are considered; CNAME targets are normalised
        to FQDN form via ``_normalize_cname_data`` so they compare
        consistently with what ``write_zone`` stores.

        NOTE(review): ``from_text`` is called without an explicit ``origin``,
        so the zone text is assumed to carry its own $ORIGIN directive —
        confirm against the DirectAdmin zone files this receives.

        Returns:
            Tuple of:
            - set of (hostname, type, data, ttl) tuples for non-SOA records
            - (hostname, soa_data, ttl) tuple for the SOA record, or None
        """
        # check_origin=False skips dnspython's origin completeness checks
        # (e.g. requiring SOA/NS records at the apex).
        dns_zone = dns_zone_module.from_text(zone_data, check_origin=False)
        records: Set[Tuple[str, str, str, int]] = set()
        soa = None

        for name, ttl, rdata in dns_zone.iterate_rdatas():
            # Skip non-Internet-class records (CH/HS).
            if rdata.rdclass != IN:
                continue

            record_name = str(name)
            record_type = rdata.rdtype.name
            record_content = rdata.to_text()

            # The SOA is returned separately; write_zone manages it apart
            # from the regular record set.
            if record_type == "SOA":
                soa = (record_name, record_content, ttl)
                continue

            if record_type == "CNAME":
                record_content = self._normalize_cname_data(zone_name, record_content)

            records.add((record_name, record_type, record_content, ttl))

        return records, soa
    def verify_zone_record_count(
        self, zone_name: str, expected_count: int
    ) -> tuple[bool, int]:
        """Verify the record count in this backend matches the expected count
        from the source (DirectAdmin) zone file.

        Args:
            zone_name: The zone to verify
            expected_count: The number of records parsed from the source BIND zone

        Returns:
            Tuple of (matches: bool, actual_count: int); ``actual_count`` is
            0 when the zone is missing and -1 on error.
        """
        session = self.Session()
        try:
            zone = (
                session.query(Zone)
                .filter_by(zone_name=self.dot_fqdn(zone_name))
                .first()
            )
            if not zone:
                logger.warning(
                    f"[{self.instance_name}] Zone {zone_name} not found "
                    f"during record count verification"
                )
                return False, 0

            # Counts every row for the zone, SOA included — the caller's
            # expected_count must be computed on the same basis.
            actual_count = session.query(Record).filter_by(zone_id=zone.id).count()
            matches = actual_count == expected_count

            if not matches:
                logger.warning(
                    f"[{self.instance_name}] Record count mismatch for "
                    f"{zone_name}: source zone has {expected_count} records, "
                    f"backend has {actual_count} records "
                    f"(difference: {actual_count - expected_count:+d})"
                )
            else:
                logger.debug(
                    f"[{self.instance_name}] Record count verified for "
                    f"{zone_name}: {actual_count} records match source"
                )

            return matches, actual_count

        except Exception as e:
            logger.error(
                f"[{self.instance_name}] Error verifying record count "
                f"for {zone_name}: {e}"
            )
            return False, -1
        finally:
            session.close()
    def reconcile_zone_records(
        self, zone_name: str, zone_data: str
    ) -> Tuple[bool, int]:
        """Reconcile backend records against the authoritative BIND zone from
        DirectAdmin. Any records in the backend that are **not** present in
        the source zone will be deleted.

        This is the post-write safety net: even though ``write_zone`` already
        removes stale records during normal processing, this method catches
        any extras that may have crept in via race conditions, manual edits,
        or replication drift between MySQL nodes.

        Args:
            zone_name: The zone to reconcile
            zone_data: The raw BIND zone file content (authoritative source)

        Returns:
            Tuple of (success: bool, records_removed: int)
        """
        session = self.Session()
        try:
            zone = (
                session.query(Zone)
                .filter_by(zone_name=self.dot_fqdn(zone_name))
                .first()
            )
            if not zone:
                logger.warning(
                    f"[{self.instance_name}] Zone {zone_name} not found "
                    f"during reconciliation"
                )
                return False, 0

            # Build the expected record set from the source BIND zone.
            # source_soa is deliberately unused: SOA rows are skipped below.
            source_records, source_soa = self._parse_zone_to_record_set(
                zone_name, zone_data
            )
            # Build lookup keys (without TTL) matching write_zone's key format
            expected_keys: Set[Tuple[str, str, str]] = {
                (hostname, rtype, data) for hostname, rtype, data, _ in source_records
            }

            # Query all records currently in the backend for this zone
            db_records = session.query(Record).filter_by(zone_id=zone.id).all()

            removed = 0
            for record in db_records:
                # SOA records are managed separately – skip them
                if record.type == "SOA":
                    continue

                key = (record.hostname, record.type, record.data)
                if key not in expected_keys:
                    logger.debug(
                        f"[{self.instance_name}] Reconcile: removing extra "
                        f"record from {zone_name}: "
                        f"{record.hostname} {record.type} {record.data}"
                    )
                    session.delete(record)
                    removed += 1

            # Only touch the database when something actually changed.
            if removed > 0:
                session.commit()
                logger.info(
                    f"[{self.instance_name}] Reconciliation for {zone_name}: "
                    f"removed {removed} extra record(s) not in source zone"
                )
            else:
                logger.debug(
                    f"[{self.instance_name}] Reconciliation for {zone_name}: "
                    f"all records match source zone — no action needed"
                )

            return True, removed

        except Exception as e:
            logger.error(
                f"[{self.instance_name}] Error reconciling records "
                f"for {zone_name}: {e}"
            )
            session.rollback()
            return False, 0
        finally:
            session.close()