"""
Tests for the implementation of the Phase 1 missing features.
"""

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from src.database.async_client import AsyncDatabaseClient
from src.database.flexible_client import FlexibleDatabaseClient


class TestFlexibleDatabaseClientEnhancements:
    """Test enhancements to FlexibleDatabaseClient."""

    @staticmethod
    def _make_client(query_timeout=30):
        """Build a FlexibleDatabaseClient for an in-memory SQLite database."""
        return FlexibleDatabaseClient(
            database_url="sqlite:///:memory:",
            query_timeout=query_timeout,
        )

    def test_query_timeout_parameter(self):
        """The query_timeout constructor argument is stored on the client."""
        client = self._make_client(query_timeout=45)

        assert client.query_timeout == 45

    def test_connection_stats_tracking(self):
        """A fresh client reports zeroed counters and echoes its configuration."""
        client = self._make_client(query_timeout=30)

        # Initial stats, before any connection attempt is made by the test.
        stats = client.get_connection_stats()
        assert stats["connection_attempts"] == 0
        assert stats["successful_connections"] == 0
        assert stats["failed_connections"] == 0
        assert stats["query_timeout"] == 30
        assert stats["database_type"] == "SQLite"

    def test_health_check_method(self):
        """health_check() fails without an engine and succeeds with a working one."""
        client = self._make_client()

        # Health check should fail when not connected (engine is None).
        assert client.health_check() is False

        # Fake engine: engine.connect() returns a context manager that yields a
        # connection whose execute(...).fetchone() returns [1].
        mock_result = Mock()
        mock_result.fetchone.return_value = [1]
        mock_conn = Mock()
        mock_conn.execute.return_value = mock_result

        # MagicMock implements the context-manager protocol out of the box;
        # its __exit__ returns False, so exceptions are not suppressed.
        connection_ctx = MagicMock()
        connection_ctx.__enter__.return_value = mock_conn

        mock_engine = Mock()
        mock_engine.connect.return_value = connection_ctx

        client.engine = mock_engine
        assert client.health_check() is True
        assert client.last_health_check is not None

        # Guard against a health check that reports success without actually
        # opening a connection and issuing a query.
        mock_engine.connect.assert_called()
        mock_conn.execute.assert_called()

    def test_is_healthy_method(self):
        """is_healthy() reports whatever health_check() returns."""
        client = self._make_client()

        for expected in (True, False):
            with patch.object(client, "health_check", return_value=expected):
                assert client.is_healthy() is expected


class TestAsyncDatabaseClient:
    """Test AsyncDatabaseClient implementation."""

    @pytest.mark.asyncio
    async def test_async_client_initialization(self):
        """The constructor stores query_timeout and leaves engine unset."""
        client = AsyncDatabaseClient(
            database_url="sqlite:///:memory:",
            query_timeout=45,
        )

        assert client.query_timeout == 45
        assert client.engine is None

    @pytest.mark.asyncio
    async def test_async_url_conversion(self):
        """Sync driver URLs are rewritten to their async-driver equivalents."""
        client = AsyncDatabaseClient(
            database_url="postgresql://user:pass@host/db",
            query_timeout=30,
        )

        async_url = client._convert_to_async_url("postgresql://user:pass@host/db")
        assert async_url == "postgresql+asyncpg://user:pass@host/db"

        async_url = client._convert_to_async_url("sqlite:///path/to/db.db")
        assert async_url == "sqlite+aiosqlite:///path/to/db.db"

    @pytest.mark.asyncio
    async def test_async_health_check(self):
        """health_check() returns True when the engine answers the probe query."""
        client = AsyncDatabaseClient(
            database_url="sqlite:///:memory:",
            query_timeout=30,
        )

        mock_result = Mock()
        mock_result.scalar.return_value = 1

        # execute() must be awaitable; AsyncMock returns mock_result when awaited.
        mock_conn = Mock()
        mock_conn.execute = AsyncMock(return_value=mock_result)

        # MagicMock provides __aenter__/__aexit__ as AsyncMocks, so it can stand
        # in for the object returned by engine.connect() in an `async with`.
        connection_ctx = MagicMock()
        connection_ctx.__aenter__.return_value = mock_conn

        mock_engine = Mock()
        mock_engine.connect.return_value = connection_ctx

        client.engine = mock_engine
        health = await client.health_check()
        assert health is True

        # Guard against a health check that succeeds without running the probe.
        mock_engine.connect.assert_called()
        mock_conn.execute.assert_awaited()

    @pytest.mark.asyncio
    async def test_async_connection_info(self):
        """get_connection_info() reports the type and pool settings of the client."""
        client = AsyncDatabaseClient(
            database_url="postgresql://user:pass@host/db",
            pool_size=5,
            max_overflow=10,
            query_timeout=45,
        )

        info = await client.get_connection_info()
        assert info["database_type"] == "PostgreSQL"
        assert info["pool_size"] == 5
        assert info["max_overflow"] == 10
        assert info["query_timeout"] == 45


class TestDatabaseSettings:
    """Checks for the database-related fields of config.settings.Settings."""

    def test_query_timeout_setting(self):
        """Settings() built without arguments exposes db_query_timeout == 30."""
        # The import is kept local to the test body, as in the original layout.
        from config.settings import Settings

        cfg = Settings()
        assert hasattr(cfg, 'db_query_timeout')
        assert cfg.db_query_timeout == 30

    def test_database_url_with_control_reader(self):
        """database_url embeds the configured user, password, host, port and name."""
        from config.settings import Settings

        connection_kwargs = {
            "db_host": "localhost",
            "db_port": 5432,
            "db_name": "calejo",
            "db_user": "control_reader",
            "db_password": "secure_password",
        }
        url = Settings(**connection_kwargs).database_url

        # Checked in the same order as the individual configuration values.
        expected_fragments = (
            "control_reader",
            "secure_password",
            "localhost",
            "5432",
            "calejo",
        )
        for fragment in expected_fragments:
            assert fragment in url