169 lines
6.2 KiB
Python
169 lines
6.2 KiB
Python
|
|
"""
|
||
|
|
Unit tests for DatabaseClient class.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import pytest_asyncio
|
||
|
|
from unittest.mock import Mock, patch, MagicMock
|
||
|
|
from typing import Dict, Any
|
||
|
|
|
||
|
|
from src.database.client import DatabaseClient
|
||
|
|
|
||
|
|
|
||
|
|
class TestDatabaseClient:
    """Test cases for DatabaseClient.

    Each test works on a client whose connection pool (and, where needed,
    ``_get_cursor``) is replaced with mocks, so no real database is contacted.
    """

    @pytest.fixture
    def db_client(self):
        """Create a test database client whose connection pool is a Mock.

        The fixture performs no awaits, so it is a plain synchronous fixture.
        Both the sync tests and the ``@pytest.mark.asyncio`` tests below can
        request it without relying on async-fixture wrapping by a plugin.
        """
        client = DatabaseClient(
            database_url="postgresql://test:test@localhost:5432/test",
            min_connections=1,
            max_connections=2,
        )

        # Mock the connection pool so no test opens a real connection.
        client.connection_pool = Mock()
        client.connection_pool.getconn.return_value = Mock()
        client.connection_pool.putconn = Mock()

        return client

    @staticmethod
    def _install_cursor(db_client, cursor):
        """Make ``db_client._get_cursor()`` act as a context manager yielding *cursor*.

        ``__exit__`` returns None (falsy), so an exception raised inside the
        ``with`` block propagates rather than being suppressed. The failure
        tests rely on this.
        """
        # NOTE(review): assumes the client uses `with self._get_cursor() as cur`,
        # as the original per-test mocks did — confirm against src/database/client.
        db_client._get_cursor = MagicMock()
        db_client._get_cursor.return_value.__enter__.return_value = cursor
        db_client._get_cursor.return_value.__exit__.return_value = None

    @pytest.mark.asyncio
    async def test_connect_success(self, db_client):
        """Test successful database connection."""
        # Start from the "not connected" state. The fixture pre-installs a Mock
        # pool, which would make any post-connect "pool exists" check vacuous.
        db_client.connection_pool = None

        # NOTE(review): this patches the class at its definition site; it only
        # intercepts construction if the client resolves it via `psycopg2.pool`
        # at call time (not `from psycopg2.pool import ThreadedConnectionPool`).
        with patch('psycopg2.pool.ThreadedConnectionPool') as mock_pool:
            mock_pool_instance = Mock()
            mock_pool.return_value = mock_pool_instance

            # Connection/cursor handed out by the new pool, in case connect()
            # runs a verification query through it.
            mock_conn = MagicMock()
            mock_cursor = MagicMock()
            mock_cursor.fetchone.return_value = {'version': 'PostgreSQL 14.0'}
            mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
            mock_conn.cursor.return_value.__exit__.return_value = None

            mock_pool_instance.getconn.return_value = mock_conn

            await db_client.connect()

            # The client must keep the pool that connect() constructed.
            assert db_client.connection_pool is mock_pool_instance
            mock_pool.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self, db_client):
        """Test that a pool-construction error propagates out of connect()."""
        with patch('psycopg2.pool.ThreadedConnectionPool') as mock_pool:
            mock_pool.side_effect = Exception("Connection failed")

            with pytest.raises(Exception, match="Connection failed"):
                await db_client.connect()

    def test_execute_query_success(self, db_client):
        """Test successful query execution."""
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = [{'id': 1, 'name': 'test'}]
        self._install_cursor(db_client, mock_cursor)

        # The query has a placeholder so the params tuple is meaningful.
        result = db_client.execute_query("SELECT * FROM test WHERE id = %s", (1,))

        assert result == [{'id': 1, 'name': 'test'}]
        mock_cursor.execute.assert_called_once_with(
            "SELECT * FROM test WHERE id = %s", (1,)
        )

    def test_execute_query_failure(self, db_client):
        """Test that a cursor error propagates out of execute_query()."""
        mock_cursor = MagicMock()
        mock_cursor.execute.side_effect = Exception("Query failed")
        self._install_cursor(db_client, mock_cursor)

        with pytest.raises(Exception, match="Query failed"):
            db_client.execute_query("SELECT * FROM test")

    def test_execute_success(self, db_client):
        """Test successful execute operation."""
        mock_cursor = MagicMock()
        self._install_cursor(db_client, mock_cursor)

        db_client.execute("INSERT INTO test VALUES (%s)", (1,))

        mock_cursor.execute.assert_called_once_with("INSERT INTO test VALUES (%s)", (1,))

    def test_health_check_success(self, db_client):
        """Test successful health check."""
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = {'health_check': 1}
        self._install_cursor(db_client, mock_cursor)

        result = db_client.health_check()

        assert result is True
        mock_cursor.execute.assert_called_once_with("SELECT 1 as health_check;")

    def test_health_check_failure(self, db_client):
        """Test that health_check() reports False when the query raises."""
        mock_cursor = MagicMock()
        mock_cursor.execute.side_effect = Exception("Health check failed")
        self._install_cursor(db_client, mock_cursor)

        result = db_client.health_check()

        assert result is False

    def test_get_connection_stats(self, db_client):
        """Test connection statistics retrieval."""
        stats = db_client.get_connection_stats()

        expected = {
            "min_connections": 1,
            "max_connections": 2,
            "pool_status": "active",
        }
        assert stats == expected

    def test_get_connection_stats_no_pool(self):
        """Test connection statistics when pool is not initialized."""
        client = DatabaseClient("test_url")
        stats = client.get_connection_stats()

        assert stats == {"status": "pool_not_initialized"}

    @pytest.mark.asyncio
    async def test_disconnect(self, db_client):
        """Test database disconnection."""
        # Hold a reference to the pool. disconnect() may reset
        # `connection_pool` (e.g. to None), so re-reading the attribute
        # afterwards could fail before the closeall assertion is checked.
        pool = db_client.connection_pool
        pool.closeall = Mock()

        await db_client.disconnect()

        pool.closeall.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect_no_pool(self, db_client):
        """Test disconnection when no pool exists."""
        db_client.connection_pool = None

        # Should not raise an exception
        await db_client.disconnect()
|