You've already forked directdnsonly
chore: complete SQLAlchemy 2.0 migration in coredns_mysql backend and tests ⬆️
Migrate remaining session.query() calls in coredns_mysql.py to select()/session.execute() style; update bulk delete to delete() construct and count to func.count(); drop sessionmaker(bind=). Update test fixtures and assertions to match. Zero session.query() calls remaining across the entire codebase.
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional, Dict, Set, Tuple, Any
|
from typing import Optional, Dict, Set, Tuple, Any
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean
|
from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean, select, func, delete
|
||||||
from sqlalchemy.orm import sessionmaker, scoped_session, relationship, declarative_base
|
from sqlalchemy.orm import sessionmaker, scoped_session, relationship, declarative_base
|
||||||
from dns import zone as dns_zone_module
|
from dns import zone as dns_zone_module
|
||||||
from dns.rdataclass import IN
|
from dns.rdataclass import IN
|
||||||
@@ -45,7 +45,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
pool_size=5,
|
pool_size=5,
|
||||||
max_overflow=10,
|
max_overflow=10,
|
||||||
)
|
)
|
||||||
self.Session = scoped_session(sessionmaker(bind=self.engine))
|
self.Session = scoped_session(sessionmaker(self.engine))
|
||||||
Base.metadata.create_all(self.engine)
|
Base.metadata.create_all(self.engine)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initialized CoreDNS MySQL backend '{self.instance_name}' "
|
f"Initialized CoreDNS MySQL backend '{self.instance_name}' "
|
||||||
@@ -79,7 +79,7 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
# Get existing records for this zone but track SOA records separately
|
# Get existing records for this zone but track SOA records separately
|
||||||
existing_records = {}
|
existing_records = {}
|
||||||
existing_soa = None
|
existing_soa = None
|
||||||
for r in session.query(Record).filter_by(zone_id=zone.id).all():
|
for r in session.execute(select(Record).filter_by(zone_id=zone.id)).scalars().all():
|
||||||
if r.type == "SOA":
|
if r.type == "SOA":
|
||||||
existing_soa = r
|
existing_soa = r
|
||||||
else:
|
else:
|
||||||
@@ -191,17 +191,17 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
session = self.Session()
|
session = self.Session()
|
||||||
try:
|
try:
|
||||||
# First find the zone
|
# First find the zone
|
||||||
zone = (
|
zone = session.execute(
|
||||||
session.query(Zone)
|
select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
).scalar_one_or_none()
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not zone:
|
if not zone:
|
||||||
logger.warning(f"Zone {zone_name} not found for deletion")
|
logger.warning(f"Zone {zone_name} not found for deletion")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Delete all records associated with the zone
|
# Delete all records associated with the zone
|
||||||
count = session.query(Record).filter_by(zone_id=zone.id).delete()
|
count = session.execute(
|
||||||
|
delete(Record).where(Record.zone_id == zone.id)
|
||||||
|
).rowcount
|
||||||
|
|
||||||
# Delete the zone itself
|
# Delete the zone itself
|
||||||
session.delete(zone)
|
session.delete(zone)
|
||||||
@@ -228,12 +228,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
def zone_exists(self, zone_name: str) -> bool:
|
def zone_exists(self, zone_name: str) -> bool:
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
try:
|
try:
|
||||||
exists = (
|
exists = session.execute(
|
||||||
session.query(Zone)
|
select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
).scalar_one_or_none() is not None
|
||||||
.first()
|
|
||||||
is not None
|
|
||||||
)
|
|
||||||
logger.debug(f"Zone existence check for {zone_name}: {exists}")
|
logger.debug(f"Zone existence check for {zone_name}: {exists}")
|
||||||
return exists
|
return exists
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -244,7 +241,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
|
|
||||||
def _ensure_zone_exists(self, session, zone_name: str) -> Zone:
|
def _ensure_zone_exists(self, session, zone_name: str) -> Zone:
|
||||||
"""Ensure a zone exists in the database, creating it if necessary"""
|
"""Ensure a zone exists in the database, creating it if necessary"""
|
||||||
zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first()
|
zone = session.execute(
|
||||||
|
select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||||
|
).scalar_one_or_none()
|
||||||
if not zone:
|
if not zone:
|
||||||
logger.debug(f"Creating new zone: {self.dot_fqdn(zone_name)}")
|
logger.debug(f"Creating new zone: {self.dot_fqdn(zone_name)}")
|
||||||
zone = Zone(zone_name=self.dot_fqdn(zone_name))
|
zone = Zone(zone_name=self.dot_fqdn(zone_name))
|
||||||
@@ -322,11 +321,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
"""
|
"""
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
try:
|
try:
|
||||||
zone = (
|
zone = session.execute(
|
||||||
session.query(Zone)
|
select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
).scalar_one_or_none()
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not zone:
|
if not zone:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[{self.instance_name}] Zone {zone_name} not found "
|
f"[{self.instance_name}] Zone {zone_name} not found "
|
||||||
@@ -334,7 +331,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
)
|
)
|
||||||
return False, 0
|
return False, 0
|
||||||
|
|
||||||
actual_count = session.query(Record).filter_by(zone_id=zone.id).count()
|
actual_count = session.execute(
|
||||||
|
select(func.count()).select_from(Record).where(Record.zone_id == zone.id)
|
||||||
|
).scalar()
|
||||||
matches = actual_count == expected_count
|
matches = actual_count == expected_count
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
@@ -382,11 +381,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
"""
|
"""
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
try:
|
try:
|
||||||
zone = (
|
zone = session.execute(
|
||||||
session.query(Zone)
|
select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name))
|
||||||
.filter_by(zone_name=self.dot_fqdn(zone_name))
|
).scalar_one_or_none()
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not zone:
|
if not zone:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[{self.instance_name}] Zone {zone_name} not found "
|
f"[{self.instance_name}] Zone {zone_name} not found "
|
||||||
@@ -404,7 +401,9 @@ class CoreDNSMySQLBackend(DNSBackend):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Query all records currently in the backend for this zone
|
# Query all records currently in the backend for this zone
|
||||||
db_records = session.query(Record).filter_by(zone_id=zone.id).all()
|
db_records = session.execute(
|
||||||
|
select(Record).where(Record.zone_id == zone.id)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
removed = 0
|
removed = 0
|
||||||
for record in db_records:
|
for record in db_records:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
|
"""Tests for the CoreDNS MySQL backend (run against in-memory SQLite)."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, select
|
||||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
|
|
||||||
from directdnsonly.app.backends.coredns_mysql import (
|
from directdnsonly.app.backends.coredns_mysql import (
|
||||||
@@ -28,7 +28,7 @@ def mysql_backend():
|
|||||||
self.config = {}
|
self.config = {}
|
||||||
self.instance_name = "test"
|
self.instance_name = "test"
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.Session = scoped_session(sessionmaker(bind=engine))
|
self.Session = scoped_session(sessionmaker(engine))
|
||||||
|
|
||||||
yield _TestBackend()
|
yield _TestBackend()
|
||||||
engine.dispose()
|
engine.dispose()
|
||||||
@@ -84,8 +84,8 @@ def test_write_zone_removes_stale_records(mysql_backend):
|
|||||||
mysql_backend.write_zone("example.com", reduced)
|
mysql_backend.write_zone("example.com", reduced)
|
||||||
|
|
||||||
session = mysql_backend.Session()
|
session = mysql_backend.Session()
|
||||||
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
|
zone = session.execute(select(Zone).filter_by(zone_name="example.com.")).scalar_one_or_none()
|
||||||
records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all()
|
records = session.execute(select(Record).filter_by(zone_id=zone.id, type="AAAA")).scalars().all()
|
||||||
assert records == []
|
assert records == []
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ def test_reconcile_removes_extra_records(mysql_backend):
|
|||||||
|
|
||||||
# Inject a phantom record directly into the DB
|
# Inject a phantom record directly into the DB
|
||||||
session = mysql_backend.Session()
|
session = mysql_backend.Session()
|
||||||
zone = session.query(Zone).filter_by(zone_name="example.com.").first()
|
zone = session.execute(select(Zone).filter_by(zone_name="example.com.")).scalar_one_or_none()
|
||||||
session.add(
|
session.add(
|
||||||
Record(
|
Record(
|
||||||
zone_id=zone.id,
|
zone_id=zone.id,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import select, func
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
from directdnsonly.app.peer_sync import PeerSyncWorker
|
from directdnsonly.app.peer_sync import PeerSyncWorker
|
||||||
@@ -123,7 +124,9 @@ def test_sync_creates_new_local_record(patch_connect, monkeypatch):
|
|||||||
|
|
||||||
worker._sync_from_peer(_make_peer())
|
worker._sync_from_peer(_make_peer())
|
||||||
|
|
||||||
record = session.query(Domain).filter_by(domain="example.com").first()
|
record = session.execute(
|
||||||
|
select(Domain).filter_by(domain="example.com")
|
||||||
|
).scalar_one_or_none()
|
||||||
assert record is not None
|
assert record is not None
|
||||||
assert record.zone_data == ZONE_DATA
|
assert record.zone_data == ZONE_DATA
|
||||||
assert record.zone_updated_at == NOW
|
assert record.zone_updated_at == NOW
|
||||||
@@ -152,7 +155,6 @@ def test_sync_updates_older_local_record(patch_connect, monkeypatch):
|
|||||||
|
|
||||||
worker._sync_from_peer(_make_peer())
|
worker._sync_from_peer(_make_peer())
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
record = session.execute(
|
record = session.execute(
|
||||||
select(Domain).filter_by(domain="example.com")
|
select(Domain).filter_by(domain="example.com")
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
@@ -187,7 +189,9 @@ def test_sync_skips_when_local_is_newer(patch_connect, monkeypatch):
|
|||||||
|
|
||||||
# zone_data fetch should not have been called
|
# zone_data fetch should not have been called
|
||||||
assert not fetch_calls
|
assert not fetch_calls
|
||||||
record = session.query(Domain).filter_by(domain="example.com").first()
|
record = session.execute(
|
||||||
|
select(Domain).filter_by(domain="example.com")
|
||||||
|
).scalar_one_or_none()
|
||||||
assert record.zone_data == "newer local"
|
assert record.zone_data == "newer local"
|
||||||
|
|
||||||
|
|
||||||
@@ -219,7 +223,7 @@ def test_sync_skips_peer_with_bad_status(patch_connect, monkeypatch):
|
|||||||
worker._sync_from_peer(_make_peer())
|
worker._sync_from_peer(_make_peer())
|
||||||
|
|
||||||
# No records should have been created
|
# No records should have been created
|
||||||
assert session.query(Domain).count() == 0
|
assert session.execute(select(func.count()).select_from(Domain)).scalar() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch):
|
def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch):
|
||||||
@@ -241,7 +245,7 @@ def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch):
|
|||||||
|
|
||||||
worker._sync_from_peer(_make_peer())
|
worker._sync_from_peer(_make_peer())
|
||||||
|
|
||||||
assert session.query(Domain).count() == 0
|
assert session.execute(select(func.count()).select_from(Domain)).scalar() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_sync_empty_peer_list(patch_connect, monkeypatch):
|
def test_sync_empty_peer_list(patch_connect, monkeypatch):
|
||||||
|
|||||||
Reference in New Issue
Block a user