diff --git a/directdnsonly/app/backends/coredns_mysql.py b/directdnsonly/app/backends/coredns_mysql.py index 5cda61d..951a4fa 100644 --- a/directdnsonly/app/backends/coredns_mysql.py +++ b/directdnsonly/app/backends/coredns_mysql.py @@ -1,6 +1,6 @@ from typing import Optional, Dict, Set, Tuple, Any -from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean +from sqlalchemy import create_engine, Column, String, Integer, Text, ForeignKey, Boolean, select, func, delete from sqlalchemy.orm import sessionmaker, scoped_session, relationship, declarative_base from dns import zone as dns_zone_module from dns.rdataclass import IN @@ -45,7 +45,7 @@ class CoreDNSMySQLBackend(DNSBackend): pool_size=5, max_overflow=10, ) - self.Session = scoped_session(sessionmaker(bind=self.engine)) + self.Session = scoped_session(sessionmaker(self.engine)) Base.metadata.create_all(self.engine) logger.info( f"Initialized CoreDNS MySQL backend '{self.instance_name}' " @@ -79,7 +79,7 @@ class CoreDNSMySQLBackend(DNSBackend): # Get existing records for this zone but track SOA records separately existing_records = {} existing_soa = None - for r in session.query(Record).filter_by(zone_id=zone.id).all(): + for r in session.execute(select(Record).filter_by(zone_id=zone.id)).scalars().all(): if r.type == "SOA": existing_soa = r else: @@ -191,17 +191,17 @@ class CoreDNSMySQLBackend(DNSBackend): session = self.Session() try: # First find the zone - zone = ( - session.query(Zone) - .filter_by(zone_name=self.dot_fqdn(zone_name)) - .first() - ) + zone = session.execute( + select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)) + ).scalar_one_or_none() if not zone: logger.warning(f"Zone {zone_name} not found for deletion") return False # Delete all records associated with the zone - count = session.query(Record).filter_by(zone_id=zone.id).delete() + count = session.execute( + delete(Record).where(Record.zone_id == zone.id) + ).rowcount # Delete the zone itself session.delete(zone) @@ -228,12 +228,9 @@ class CoreDNSMySQLBackend(DNSBackend): def zone_exists(self, zone_name: str) -> bool: session = self.Session() try: - exists = ( - session.query(Zone) - .filter_by(zone_name=self.dot_fqdn(zone_name)) - .first() - is not None - ) + exists = session.execute( + select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)) + ).scalar_one_or_none() is not None logger.debug(f"Zone existence check for {zone_name}: {exists}") return exists except Exception as e: @@ -244,7 +241,9 @@ class CoreDNSMySQLBackend(DNSBackend): def _ensure_zone_exists(self, session, zone_name: str) -> Zone: """Ensure a zone exists in the database, creating it if necessary""" - zone = session.query(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)).first() + zone = session.execute( + select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)) + ).scalar_one_or_none() if not zone: logger.debug(f"Creating new zone: {self.dot_fqdn(zone_name)}") zone = Zone(zone_name=self.dot_fqdn(zone_name)) @@ -322,11 +321,9 @@ class CoreDNSMySQLBackend(DNSBackend): """ session = self.Session() try: - zone = ( - session.query(Zone) - .filter_by(zone_name=self.dot_fqdn(zone_name)) - .first() - ) + zone = session.execute( + select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)) + ).scalar_one_or_none() if not zone: logger.warning( f"[{self.instance_name}] Zone {zone_name} not found " @@ -334,7 +331,9 @@ class CoreDNSMySQLBackend(DNSBackend): ) return False, 0 - actual_count = session.query(Record).filter_by(zone_id=zone.id).count() + actual_count = session.execute( + select(func.count()).select_from(Record).where(Record.zone_id == zone.id) + ).scalar() matches = actual_count == expected_count if not matches: @@ -382,11 +381,9 @@ class CoreDNSMySQLBackend(DNSBackend): """ session = self.Session() try: - zone = ( - session.query(Zone) - .filter_by(zone_name=self.dot_fqdn(zone_name)) - .first() - ) + zone = session.execute( + select(Zone).filter_by(zone_name=self.dot_fqdn(zone_name)) + ).scalar_one_or_none() if not zone: logger.warning( f"[{self.instance_name}] Zone {zone_name} not found " @@ -404,7 +401,9 @@ class CoreDNSMySQLBackend(DNSBackend): } # Query all records currently in the backend for this zone - db_records = session.query(Record).filter_by(zone_id=zone.id).all() + db_records = session.execute( + select(Record).where(Record.zone_id == zone.id) + ).scalars().all() removed = 0 for record in db_records: diff --git a/tests/test_coredns_mysql.py b/tests/test_coredns_mysql.py index 977bb5e..80b1e36 100644 --- a/tests/test_coredns_mysql.py +++ b/tests/test_coredns_mysql.py @@ -1,7 +1,7 @@ """Tests for the CoreDNS MySQL backend (run against in-memory SQLite).""" import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, select from sqlalchemy.orm import scoped_session, sessionmaker from directdnsonly.app.backends.coredns_mysql import ( @@ -28,7 +28,7 @@ def mysql_backend(): self.config = {} self.instance_name = "test" self.engine = engine - self.Session = scoped_session(sessionmaker(bind=engine)) + self.Session = scoped_session(sessionmaker(engine)) yield _TestBackend() engine.dispose() @@ -84,8 +84,8 @@ def test_write_zone_removes_stale_records(mysql_backend): mysql_backend.write_zone("example.com", reduced) session = mysql_backend.Session() - zone = session.query(Zone).filter_by(zone_name="example.com.").first() - records = session.query(Record).filter_by(zone_id=zone.id, type="AAAA").all() + zone = session.execute(select(Zone).filter_by(zone_name="example.com.")).scalar_one_or_none() + records = session.execute(select(Record).filter_by(zone_id=zone.id, type="AAAA")).scalars().all() assert records == [] session.close() @@ -141,7 +141,7 @@ def test_reconcile_removes_extra_records(mysql_backend): # Inject a phantom record directly into the DB session = mysql_backend.Session() - zone = session.query(Zone).filter_by(zone_name="example.com.").first() + zone = session.execute(select(Zone).filter_by(zone_name="example.com.")).scalar_one_or_none() session.add( Record( zone_id=zone.id, diff --git a/tests/test_peer_sync.py b/tests/test_peer_sync.py index 0fc2475..3a68fbb 100644 --- a/tests/test_peer_sync.py +++ b/tests/test_peer_sync.py @@ -3,6 +3,7 @@ import datetime import json import pytest +from sqlalchemy import select, func from unittest.mock import patch, MagicMock from directdnsonly.app.peer_sync import PeerSyncWorker @@ -123,7 +124,9 @@ def test_sync_creates_new_local_record(patch_connect, monkeypatch): worker._sync_from_peer(_make_peer()) - record = session.query(Domain).filter_by(domain="example.com").first() + record = session.execute( + select(Domain).filter_by(domain="example.com") + ).scalar_one_or_none() assert record is not None assert record.zone_data == ZONE_DATA assert record.zone_updated_at == NOW @@ -152,7 +155,6 @@ def test_sync_updates_older_local_record(patch_connect, monkeypatch): worker._sync_from_peer(_make_peer()) - from sqlalchemy import select record = session.execute( select(Domain).filter_by(domain="example.com") ).scalar_one_or_none() @@ -187,7 +189,9 @@ def test_sync_skips_when_local_is_newer(patch_connect, monkeypatch): # zone_data fetch should not have been called assert not fetch_calls - record = session.query(Domain).filter_by(domain="example.com").first() + record = session.execute( + select(Domain).filter_by(domain="example.com") + ).scalar_one_or_none() assert record.zone_data == "newer local" @@ -219,7 +223,7 @@ def test_sync_skips_peer_with_bad_status(patch_connect, monkeypatch): worker._sync_from_peer(_make_peer()) # No records should have been created - assert session.query(Domain).count() == 0 + assert session.execute(select(func.count()).select_from(Domain)).scalar() == 0 def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch): @@ -241,7 +245,7 @@ def test_sync_skips_missing_zone_data_in_response(patch_connect, monkeypatch): worker._sync_from_peer(_make_peer()) - assert session.query(Domain).count() == 0 + assert session.execute(select(func.count()).select_from(Domain)).scalar() == 0 def test_sync_empty_peer_list(patch_connect, monkeypatch):