chore: complete SQLAlchemy 2.0 migration in coredns_mysql backend and tests ⬆️

Migrate remaining session.query() calls in coredns_mysql.py to
select()/session.execute() style; update bulk delete to delete()
construct and count to func.count(); drop sessionmaker(bind=).
Update test fixtures and assertions to match.

Zero session.query() calls remaining across the entire codebase.
This commit is contained in:
2026-02-19 23:43:54 +13:00
parent d81ecd6bdd
commit f9907d2859
3 changed files with 41 additions and 38 deletions

View File

@@ -1,6 +1,6 @@
from typing import Optional, Dict, Set, Tuple, Any from typing import Optional, Dict, Set, Tuple, Any
from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean, select, func, delete
from sqlalchemy.orm import sessionmaker, scoped_session, relationship, declarative_base from sqlalchemy.orm import sessionmaker, scoped_session, relationship, declarative_base
from dns import zone as dns_zone_module from dns import zone as dns_zone_module
from dns.rdataclass import IN from dns.rdataclass import IN
@@ -45,7 +45,7 @@ class CoreDNSMySQLBackend(DNSBackend):
pool_size=5, pool_size=5,
max_overflow=10, max_overflow=10,
) )
self.Session = scoped_session(sessionmaker(bind=self.engine)) self.Session = scoped_session(sessionmaker(self.engine))
Base.metadata.create_all(self.engine) Base.metadata.create_all(self.engine)
logger.info( logger.info(
f"Initialized CoreDNS MySQL backend '{self.instance_name}' " f"Initialized CoreDNS MySQL backend '{self.instance_name}' "
@@ -79,7 +79,7 @@ class CoreDNSMySQLBackend(DNSBackend):
# Get existing records for this zone but track SOA records separately # Get existing records for this zone but track SOA records separately
existing_records = {} existing_records = {}
existing_soa = None existing_soa = None
for r in session.query(Record).filter_by(zone_id=zone.id).all(): for r in session.execute(select(Record).filter_by(zone_id=zone.id)).scalars().all():
if r.type == "SOA": if r.type == "SOA":
existing_soa = r existing_soa = r
else: else:
@@ -191,17 +191,17 @@ class CoreDNSMySQLBackend(DNSBackend):
session = self.Session() session = self.Session()
try: try:
# First find the zone # First find the zone
zone = ( zone = session.execute(
session.query(Zone) select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
.filter_by(zone_name=self.dot_fqdn(zone_name)) ).scalar_one_or_none()
.first()
)
if not zone: if not zone:
logger.warning(f"Zone {zone_name} not found for deletion") logger.warning(f"Zone {zone_name} not found for deletion")
return False return False
# Delete all records associated with the zone # Delete all records associated with the zone
count = session.query(Record).filter_by(zone_id=zone.id).delete() count = session.execute(
delete(Record).where(Record.zone_id == zone.id)
).rowcount
# Delete the zone itself # Delete the zone itself
session.delete(zone) session.delete(zone)
@@ -228,12 +228,9 @@ class CoreDNSMySQLBackend(DNSBackend):
def zone_exists(self, zone_name: str) -> bool: def zone_exists(self, zone_name: str) -> bool:
session = self.Session() session = self.Session()
try: try:
exists = ( exists = session.execute(
session.query(Zone) select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
.filter_by(zone_name=self.dot_fqdn(zone_name)) ).scalar_one_or_none() is not None
.first()
is not None
)
logger.debug(f"Zone existence check for {zone_name}: {exists}") logger.debug(f"Zone existence check for {zone_name}: {exists}")
return exists return exists
except Exception as e: except Exception as e:
@@ -244,7 +241,9 @@ class CoreDNSMySQLBackend(DNSBackend):
def _ensure_zone_exists(self, session, zone_name: str) -> Zone: def _ensure_zone_exists(self, session, zone_name: str) -> Zone:
"""Ensure a zone exists in the database, creating it if necessary""" """Ensure a zone exists in the database, creating it if necessary"""
zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first() zone = session.execute(
select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
).scalar_one_or_none()
if not zone: if not zone:
logger.debug(f"Creating new zone: {self.dot_fqdn(zone_name)}") logger.debug(f"Creating new zone: {self.dot_fqdn(zone_name)}")
zone = Zone(zone_name=self.dot_fqdn(zone_name)) zone = Zone(zone_name=self.dot_fqdn(zone_name))
@@ -322,11 +321,9 @@ class CoreDNSMySQLBackend(DNSBackend):
""" """
session = self.Session() session = self.Session()
try: try:
zone = ( zone = session.execute(
session.query(Zone) select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
.filter_by(zone_name=self.dot_fqdn(zone_name)) ).scalar_one_or_none()
.first()
)
if not zone: if not zone:
logger.warning( logger.warning(
f"[{self.instance_name}] Zone {zone_name} not found " f"[{self.instance_name}] Zone {zone_name} not found "
@@ -334,7 +331,9 @@ class CoreDNSMySQLBackend(DNSBackend):
) )
return False, 0 return False, 0
actual_count = session.query(Record).filter_by(zone_id=zone.id).count() actual_count = session.execute(
select(func.count()).select_from(Record).where(Record.zone_id == zone.id)
).scalar()
matches = actual_count == expected_count matches = actual_count == expected_count
if not matches: if not matches:
@@ -382,11 +381,9 @@ class CoreDNSMySQLBackend(DNSBackend):
""" """
session = self.Session() session = self.Session()
try: try:
zone = ( zone = session.execute(
session.query(Zone) select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
.filter_by(zone_name=self.dot_fqdn(zone_name)) ).scalar_one_or_none()
.first()
)
if not zone: if not zone:
logger.warning( logger.warning(
f"[{self.instance_name}] Zone {zone_name} not found " f"[{self.instance_name}] Zone {zone_name} not found "
@@ -404,7 +401,9 @@ class CoreDNSMySQLBackend(DNSBackend):
} }
# Query all records currently in the backend for this zone # Query all records currently in the backend for this zone
db_records = session.query(Record).filter_by(zone_id=zone.id).all() db_records = session.execute(
select(Record).where(Record.zone_id == zone.id)
).scalars().all()
removed = 0 removed = 0
for record in db_records: for record in db_records:

View File

@@ -1,7 +1,7 @@
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite).""" """Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine, select
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from directdnsonly.app.backends.coredns_mysql import ( from directdnsonly.app.backends.coredns_mysql import (
@@ -28,7 +28,7 @@ def mysql_backend():
self.config = {} self.config = {}
self.instance_name = "test" self.instance_name = "test"
self.engine = engine self.engine = engine
self.Session = scoped_session(sessionmaker(bind=engine)) self.Session = scoped_session(sessionmaker(engine))
yield _TestBackend() yield _TestBackend()
engine.dispose() engine.dispose()
@@ -84,8 +84,8 @@ def test_write_zone_removes_stale_records(mysql_backend):
mysql_backend.write_zone("example.com", reduced) mysql_backend.write_zone("example.com", reduced)
session = mysql_backend.Session() session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first() zone = session.execute(select(Zone).filter_by(zone_name="example.com.")).scalar_one_or_none()
records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all() records = session.execute(select(Record).filter_by(zone_id=zone.id, type="AAAA")).scalars().all()
assert records == [] assert records == []
session.close() session.close()
@@ -141,7 +141,7 @@ def test_reconcile_removes_extra_records(mysql_backend):
# Inject a phantom record directly into the DB # Inject a phantom record directly into the DB
session = mysql_backend.Session() session = mysql_backend.Session()
zone = session.query(Zone).filter_by(zone_name="example.com.").first() zone = session.execute(select(Zone).filter_by(zone_name="example.com.")).scalar_one_or_none()
session.add( session.add(
Record( Record(
zone_id=zone.id, zone_id=zone.id,

View File

@@ -3,6 +3,7 @@
import datetime import datetime
import json import json
import pytest import pytest
from sqlalchemy import select, func
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from directdnsonly.app.peer_sync import PeerSyncWorker from directdnsonly.app.peer_sync import PeerSyncWorker
@@ -123,7 +124,9 @@ def test_sync_creates_new_local_record(patch_connect, monkeypatch):
worker._sync_from_peer(_make_peer()) worker._sync_from_peer(_make_peer())
record = session.query(Domain).filter_by(domain="example.com").first() record = session.execute(
select(Domain).filter_by(domain="example.com")
).scalar_one_or_none()
assert record is not None assert record is not None
assert record.zone_data == ZONE_DATA assert record.zone_data == ZONE_DATA
assert record.zone_updated_at == NOW assert record.zone_updated_at == NOW
@@ -152,7 +155,6 @@ def test_sync_updates_older_local_record(patch_connect, monkeypatch):
worker._sync_from_peer(_make_peer()) worker._sync_from_peer(_make_peer())
from sqlalchemy import select
record = session.execute( record = session.execute(
select(Domain).filter_by(domain="example.com") select(Domain).filter_by(domain="example.com")
).scalar_one_or_none() ).scalar_one_or_none()
@@ -187,7 +189,9 @@ def test_sync_skips_when_local_is_newer(patch_connect, monkeypatch):
# zone_data fetch should not have been called # zone_data fetch should not have been called
assert not fetch_calls assert not fetch_calls
record = session.query(Domain).filter_by(domain="example.com").first() record = session.execute(
select(Domain).filter_by(domain="example.com")
).scalar_one_or_none()
assert record.zone_data == "newer local" assert record.zone_data == "newer local"
@@ -219,7 +223,7 @@ def test_sync_skips_peer_with_bad_status(patch_connect, monkeypatch):
worker._sync_from_peer(_make_peer()) worker._sync_from_peer(_make_peer())
# No records should have been created # No records should have been created
assert session.query(Domain).count() == 0 assert session.execute(select(func.count()).select_from(Domain)).scalar() == 0
def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch): def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch):
@@ -241,7 +245,7 @@ def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch):
worker._sync_from_peer(_make_peer()) worker._sync_from_peer(_make_peer())
assert session.query(Domain).count() == 0 assert session.execute(select(func.count()).select_from(Domain)).scalar() == 0
def test_sync_empty_peer_list(patch_connect, monkeypatch): def test_sync_empty_peer_list(patch_connect, monkeypatch):